From 95da68d938534eaccce85b5143edd6edf6d26449 Mon Sep 17 00:00:00 2001 From: "yuxuan.wang1" Date: Fri, 13 Sep 2024 16:38:28 +0800 Subject: [PATCH] feat: optimize gRPC error handling --- pkg/kerrors/kerrors.go | 3 +- pkg/kerrors/kerrors_test.go | 12 +- pkg/kerrors/streaming_errors.go | 61 +++ pkg/kerrors/streaming_errors_test.go | 40 ++ pkg/remote/trans/nphttp2/conn_pool_test.go | 4 +- pkg/remote/trans/nphttp2/errors/errors.go | 77 ++++ .../trans/nphttp2/errors/errors_test.go | 64 ++++ pkg/remote/trans/nphttp2/grpc/controlbuf.go | 16 +- .../trans/nphttp2/grpc/controlbuf_test.go | 46 ++- .../trans/nphttp2/grpc/err_handling_test.go | 359 ++++++++++++++++++ pkg/remote/trans/nphttp2/grpc/http2_client.go | 124 ++++-- pkg/remote/trans/nphttp2/grpc/http2_server.go | 173 +++++---- pkg/remote/trans/nphttp2/grpc/http_util.go | 133 ++++++- .../trans/nphttp2/grpc/http_util_test.go | 192 ++++++++++ pkg/remote/trans/nphttp2/grpc/stream_test.go | 36 ++ pkg/remote/trans/nphttp2/grpc/transport.go | 122 +++++- .../trans/nphttp2/grpc/transport_test.go | 69 +++- pkg/remote/trans/nphttp2/server_handler.go | 1 + pkg/remote/trans/nphttp2/status/status.go | 41 +- .../trans/nphttp2/status/status_test.go | 33 +- pkg/streamx/provider/grpc/gerrors/gerrors.go | 62 +++ 21 files changed, 1483 insertions(+), 185 deletions(-) create mode 100644 pkg/kerrors/streaming_errors.go create mode 100644 pkg/kerrors/streaming_errors_test.go create mode 100644 pkg/remote/trans/nphttp2/errors/errors.go create mode 100644 pkg/remote/trans/nphttp2/errors/errors_test.go create mode 100644 pkg/remote/trans/nphttp2/grpc/err_handling_test.go create mode 100644 pkg/remote/trans/nphttp2/grpc/stream_test.go create mode 100644 pkg/streamx/provider/grpc/gerrors/gerrors.go diff --git a/pkg/kerrors/kerrors.go b/pkg/kerrors/kerrors.go index 783658f27e..f0468ba179 100644 --- a/pkg/kerrors/kerrors.go +++ b/pkg/kerrors/kerrors.go @@ -190,7 +190,8 @@ func IsKitexError(err error) bool { if _, ok := err.(*DetailedError); ok { return true } - return false + + return IsStreamingError(err) } // TimeoutCheckFunc is used to check whether the given err is a timeout error. diff --git a/pkg/kerrors/kerrors_test.go b/pkg/kerrors/kerrors_test.go index 1cd65cb257..fc580de1af 100644 --- a/pkg/kerrors/kerrors_test.go +++ b/pkg/kerrors/kerrors_test.go @@ -47,6 +47,16 @@ func TestIsKitexError(t *testing.T) { ErrNoMoreInstance, ErrConnOverLimit, ErrQPSOverLimit, + // streaming errors + ErrStreamingProtocol, + errStreamingTimeout, + ErrStreamTimeout, + ErrStreamingCanceled, + ErrBizCanceled, + ErrGracefulShutdown, + errStreamingMeta, + ErrMetaSizeExceeded, + ErrMetaContentIllegal, } for _, e := range errs { test.Assert(t, IsKitexError(e)) @@ -204,7 +214,7 @@ func TestFormat(t *testing.T) { error: errors.New("some_business_error"), } basicErr := &basicError{ - message: "fake_msg", + "fake_msg", } err := basicErr.WithCause(businessError) got := fmt.Sprintf("%+v", err) diff --git a/pkg/kerrors/streaming_errors.go b/pkg/kerrors/streaming_errors.go new file mode 100644 index 0000000000..da2ec6916f --- /dev/null +++ b/pkg/kerrors/streaming_errors.go @@ -0,0 +1,61 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kerrors + +import "errors" + +var ( + // ErrStreamingProtocol is the parent type of all streaming protocol(e.g. gRPC, TTHeader Streaming) + // related but not user-aware errors. + ErrStreamingProtocol = &basicError{"streaming protocol error"} + + // errStreamingTimeout is the parent type of all streaming timeout errors. + errStreamingTimeout = &basicError{"streaming timeout error"} + // ErrStreamTimeout denotes the timeout of the whole stream. + ErrStreamTimeout = errStreamingTimeout.WithCause(errors.New("stream timeout")) + + // ErrStreamingCanceled is the parent type of all streaming canceled errors. + ErrStreamingCanceled = &basicError{"streaming canceled error"} + // ErrBizCanceled denotes the stream is canceled by the biz code invoking cancel(). + ErrBizCanceled = ErrStreamingCanceled.WithCause(errors.New("business canceled")) + // ErrGracefulShutdown denotes the stream is canceled due to graceful shutdown. + ErrGracefulShutdown = ErrStreamingCanceled.WithCause(errors.New("graceful shutdown")) + + // errStreamingMeta is the parent type of all streaming meta errors. + errStreamingMeta = &basicError{"streaming meta error"} + // ErrMetaSizeExceeded denotes the streaming meta size exceeds the limit. + ErrMetaSizeExceeded = errStreamingMeta.WithCause(errors.New("meta size exceeds limit")) + // ErrMetaContentIllegal denotes the streaming meta content is illegal. + ErrMetaContentIllegal = errStreamingMeta.WithCause(errors.New("meta content illegal")) + + streamingBasicErrors = []*basicError{ + ErrStreamingProtocol, + errStreamingTimeout, + ErrStreamingCanceled, + errStreamingMeta, + } +) + +// IsStreamingError reports whether the given err is a streaming err +func IsStreamingError(err error) bool { + for _, sErr := range streamingBasicErrors { + if errors.Is(err, sErr) { + return true + } + } + return false +} diff --git a/pkg/kerrors/streaming_errors_test.go b/pkg/kerrors/streaming_errors_test.go new file mode 100644 index 0000000000..b530df0c76 --- /dev/null +++ b/pkg/kerrors/streaming_errors_test.go @@ -0,0 +1,40 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kerrors + +import ( + "testing" + + "github.com/cloudwego/kitex/internal/test" +) + +func TestIsStreamingError(t *testing.T) { + errs := []error{ + ErrStreamingProtocol, + errStreamingTimeout, + ErrStreamTimeout, + ErrStreamingCanceled, + ErrBizCanceled, + ErrGracefulShutdown, + errStreamingMeta, + ErrMetaSizeExceeded, + ErrMetaContentIllegal, + } + for _, err := range errs { + test.Assert(t, IsStreamingError(err), err) + } +} diff --git a/pkg/remote/trans/nphttp2/conn_pool_test.go b/pkg/remote/trans/nphttp2/conn_pool_test.go index 21d20867a7..bc5f093f03 100644 --- a/pkg/remote/trans/nphttp2/conn_pool_test.go +++ b/pkg/remote/trans/nphttp2/conn_pool_test.go @@ -24,6 +24,8 @@ import ( "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/remote" + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/codes" + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/status" ) func TestConnPool(t *testing.T) { @@ -77,7 +79,7 @@ func TestReleaseConn(t *testing.T) { // close stream to ensure no active stream on this connection, // which will be released when put back to the connection pool and closed by GracefulClose s := conn.(*clientConn).s - conn.(*clientConn).tr.CloseStream(s, nil) + conn.(*clientConn).tr.CloseStream(s, status.Err(codes.Internal, "test")) test.Assert(t, err == nil, err) time.Sleep(100 * time.Millisecond) shortCP.Put(conn) diff --git a/pkg/remote/trans/nphttp2/errors/errors.go b/pkg/remote/trans/nphttp2/errors/errors.go new file mode 100644 index 0000000000..5a0cfa5efd --- /dev/null +++ b/pkg/remote/trans/nphttp2/errors/errors.go @@ -0,0 +1,77 @@ +/* + * + * Copyright 2024 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2021 CloudWeGo Authors. + */ + +package errors + +import ( + "errors" + "fmt" + + "github.com/cloudwego/kitex/pkg/kerrors" +) + +// This package contains all the errors suitable for Kitex errors model. +// These errors should not be used by user directly. +// If users need to perceive these errors, the pkg/streamx/provider/grpc/gerrors package should be used. +var ( + // stream error + ErrHTTP2Stream = newErrType("HTTP2Stream err when parsing HTTP2 frame") + ErrClosedWithoutTrailer = newErrType("client received Data frame with END_STREAM flag") + ErrMiddleHeader = newErrType("Headers frame appeared in the middle of a stream") + ErrDecodeHeader = newErrType("decoded Headers frame failed") + ErrRecvRstStream = newErrType("received RstStream frame") + ErrStreamDrain = newErrType("stream rejected by draining connection") + ErrStreamFlowControl = newErrType("stream-level flow control") + ErrIllegalHeaderWrite = newErrType("Headers frame has been already sent by server") + ErrStreamIsDone = newErrType("stream is done") + ErrMaxStreamExceeded = newErrType("max stream exceeded") + + // connection error + ErrHTTP2Connection = newErrType("HTTP2Connection err when parsing HTTP2 frame") + ErrEstablishConnection = newErrType("established connection failed") + ErrHandleGoAway = newErrType("handled GoAway Frame failed") + ErrKeepAlive = newErrType("keepalive failed") + ErrOperateHeaders = newErrType("operated Headers Frame failed") + ErrNoActiveStream = newErrType("no active stream") + ErrControlBufFinished = newErrType("controlbuf finished") + ErrNotReachable = newErrType("server transport is not reachable") + ErrConnectionIsClosing = newErrType("connection is closing") +) + +type errType struct { + message string + // parent errType + basic error +} + +func newErrType(message string) *errType { + return &errType{message: message, basic: kerrors.ErrStreamingProtocol} +} + +func (e *errType) Error() string { + if e.basic == nil { + return e.message + } + return fmt.Sprintf("%s - %s", e.basic.Error(), e.message) +} + +func (e *errType) Is(target error) bool { + return target == e || errors.Is(e.basic, target) +} diff --git a/pkg/remote/trans/nphttp2/errors/errors_test.go b/pkg/remote/trans/nphttp2/errors/errors_test.go new file mode 100644 index 0000000000..60724f2245 --- /dev/null +++ b/pkg/remote/trans/nphttp2/errors/errors_test.go @@ -0,0 +1,64 @@ +/* + * + * Copyright 2024 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2021 CloudWeGo Authors. + */ + +package errors + +import ( + "errors" + "strings" + "testing" + + "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/kerrors" +) + +var errs = []*errType{ + // stream error + ErrHTTP2Stream, + ErrClosedWithoutTrailer, + ErrMiddleHeader, + ErrDecodeHeader, + ErrRecvRstStream, + ErrStreamDrain, + ErrStreamFlowControl, + ErrIllegalHeaderWrite, + ErrStreamIsDone, + ErrMaxStreamExceeded, + // connection error + ErrHTTP2Connection, + ErrEstablishConnection, + ErrHandleGoAway, + ErrKeepAlive, + ErrOperateHeaders, + ErrNoActiveStream, + ErrControlBufFinished, + ErrNotReachable, + ErrConnectionIsClosing, +} + +func TestErrType(t *testing.T) { + for _, err := range errs { + test.Assert(t, errors.Is(err, kerrors.ErrStreamingProtocol), err) + test.Assert(t, kerrors.IsKitexError(err), err) + test.Assert(t, kerrors.IsStreamingError(err), err) + test.Assert(t, strings.Contains(err.Error(), err.message), err) + test.Assert(t, strings.Contains(err.Error(), err.basic.Error()), err) + } +} diff --git a/pkg/remote/trans/nphttp2/grpc/controlbuf.go b/pkg/remote/trans/nphttp2/grpc/controlbuf.go index 7dd4701124..809126614f 100644 --- a/pkg/remote/trans/nphttp2/grpc/controlbuf.go +++ b/pkg/remote/trans/nphttp2/grpc/controlbuf.go @@ -451,19 +451,23 @@ func (c *controlBuffer) get(block bool) (interface{}, error) { select { case <-c.ch: case <-c.done: - c.finish() - return nil, ErrConnClosing + var err error + c.finish(errStatusControlBufFinished) + c.mu.Lock() + err = c.err + c.mu.Unlock() + return nil, err } } } -func (c *controlBuffer) finish() { +func (c *controlBuffer) finish(err error) { c.mu.Lock() if c.err != nil { c.mu.Unlock() return } - c.err = ErrConnClosing + c.err = err // There may be headers for streams in the control buffer. // These streams need to be cleaned out since the transport // is still not aware of these yet. @@ -473,7 +477,7 @@ func (c *controlBuffer) finish() { continue } if hdr.onOrphaned != nil { // It will be nil on the server-side. - hdr.onOrphaned(ErrConnClosing) + hdr.onOrphaned(err) } } c.mu.Unlock() @@ -696,7 +700,7 @@ func (l *loopyWriter) originateStream(str *outStream) error { if err == ErrConnClosing { return err } - // Other errors(errStreamDrain) need not close transport. + // Other errors(errStatusStreamDrain) need not close transport. return nil } if err := l.writeHeader(str.id, hdr.endStream, hdr.hf, hdr.onWrite); err != nil { diff --git a/pkg/remote/trans/nphttp2/grpc/controlbuf_test.go b/pkg/remote/trans/nphttp2/grpc/controlbuf_test.go index 643a4e3de5..d69e660f94 100644 --- a/pkg/remote/trans/nphttp2/grpc/controlbuf_test.go +++ b/pkg/remote/trans/nphttp2/grpc/controlbuf_test.go @@ -18,6 +18,7 @@ package grpc import ( "context" + "errors" "testing" "time" @@ -25,7 +26,7 @@ import ( ) func TestControlBuf(t *testing.T) { - ctx := context.Background() + ctx, cancel := context.WithCancel(context.Background()) cb := newControlBuffer(ctx.Done()) // test put() @@ -52,7 +53,8 @@ func TestControlBuf(t *testing.T) { test.Assert(t, !success, err) // test throttle() mock a lot of response frame so throttle() will block current goroutine - for i := 0; i < maxQueuedTransportResponseFrames+5; i++ { + exceedSize := 5 + for i := 0; i < maxQueuedTransportResponseFrames+exceedSize; i++ { err := cb.put(&ping{}) test.Assert(t, err == nil, err) } @@ -60,16 +62,44 @@ func TestControlBuf(t *testing.T) { // start a new goroutine to consume response frame go func() { time.Sleep(time.Millisecond * 100) - for { + for i := 0; i < exceedSize+1; i++ { it, err := cb.get(false) - if err != nil || it == nil { - break - } + test.Assert(t, err == nil, err) + test.Assert(t, it != nil) } }() cb.throttle() + // consumes all of the frames + for { + it, err := cb.get(false) + if err != nil || it == nil { + break + } + } + + finishErr := errors.New("finish") + go func() { + ticker := time.NewTicker(10 * time.Millisecond) + defer ticker.Stop() + for range ticker.C { + var block bool + cb.mu.Lock() + block = cb.consumerWaiting + cb.mu.Unlock() + if block { + cb.finish(finishErr) + cancel() + return + } + } + }() + item, err = cb.get(true) + test.Assert(t, err == finishErr, err) + test.Assert(t, item == nil, item) - // test finish() - cb.finish() + err = cb.put(testItem) + test.Assert(t, err == finishErr, err) + _, err = cb.get(false) + test.Assert(t, err == finishErr, err) } diff --git a/pkg/remote/trans/nphttp2/grpc/err_handling_test.go b/pkg/remote/trans/nphttp2/grpc/err_handling_test.go new file mode 100644 index 0000000000..dc4769e94c --- /dev/null +++ b/pkg/remote/trans/nphttp2/grpc/err_handling_test.go @@ -0,0 +1,359 @@ +/* + * + * Copyright 2024 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2021 CloudWeGo Authors. + */ + +package grpc + +import ( + "context" + "errors" + "io" + "net" + "sync" + "testing" + "time" + + "github.com/cloudwego/netpoll" + "golang.org/x/net/http2/hpack" + + "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/kerrors" + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/codes" + gerrors "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/errors" + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/metadata" + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/status" +) + +const ( + testTypeKey string = "testType" + errBizCanceled string = "errBizCanceled" + errStreamTimeout string = "errStreamTimeout" + errMiddleHeader string = "errMiddleHeader" + errDecodeHeader string = "errDecodeHeader" + errHTTP2Stream string = "errHTTP2Stream" + errClosedWithoutTrailer string = "errClosedWithoutTrailer" + errRecvRstStream string = "errRecvRstStream" +) + +type expectedErrs map[string]struct { + cliErr error + srvErr error +} + +func (errs expectedErrs) getClientExpectedErr(testType string) error { + return errs[testType].cliErr +} + +func (errs expectedErrs) getServerExpectedErr(testType string) error { + return errs[testType].srvErr +} + +func TestErrorHandling(t *testing.T) { + t.Run("close stream", func(t *testing.T) { + testcases := []struct { + desc string + setup func(t *testing.T) + clean func(t *testing.T) + customRstCodeMapping map[uint32]error + errs expectedErrs + }{ + { + desc: "normal RstCode", + errs: expectedErrs{ + errBizCanceled: {kerrors.ErrBizCanceled, gerrors.ErrRecvRstStream}, + errStreamTimeout: {kerrors.ErrStreamTimeout, kerrors.ErrStreamTimeout}, + errMiddleHeader: {gerrors.ErrMiddleHeader, gerrors.ErrRecvRstStream}, + errDecodeHeader: {gerrors.ErrDecodeHeader, gerrors.ErrRecvRstStream}, + errHTTP2Stream: {gerrors.ErrHTTP2Stream, gerrors.ErrRecvRstStream}, + errClosedWithoutTrailer: {gerrors.ErrClosedWithoutTrailer, nil}, + errRecvRstStream: {gerrors.ErrRecvRstStream, nil}, + }, + }, + { + desc: "custom RstCode", + setup: func(t *testing.T) { + SetCustomRstCodeEnabled(true) + RegisterCustomRstCode(1000, kerrors.ErrGracefulShutdown) + RegisterCustomRstCode(1001, kerrors.ErrBizCanceled) + RegisterCustomRstCode(1002, kerrors.ErrStreamingCanceled) + RegisterCustomRstCode(1003, kerrors.ErrStreamTimeout) + RegisterCustomRstCode(1004, gerrors.ErrMiddleHeader) + RegisterCustomRstCode(1005, gerrors.ErrDecodeHeader) + RegisterCustomRstCode(1006, gerrors.ErrHTTP2Stream) + RegisterCustomRstCode(1007, gerrors.ErrClosedWithoutTrailer) + }, + clean: func(t *testing.T) { + SetCustomRstCodeEnabled(false) + }, + errs: expectedErrs{ + errBizCanceled: {kerrors.ErrBizCanceled, kerrors.ErrBizCanceled}, + errStreamTimeout: {kerrors.ErrStreamTimeout, kerrors.ErrStreamTimeout}, + errMiddleHeader: {gerrors.ErrMiddleHeader, gerrors.ErrMiddleHeader}, + errDecodeHeader: {gerrors.ErrDecodeHeader, gerrors.ErrDecodeHeader}, + errHTTP2Stream: {gerrors.ErrHTTP2Stream, gerrors.ErrHTTP2Stream}, + errClosedWithoutTrailer: {gerrors.ErrClosedWithoutTrailer, nil}, + errRecvRstStream: {kerrors.ErrGracefulShutdown, nil}, + }, + }, + } + + for _, tc := range testcases { + t.Run(tc.desc, func(t *testing.T) { + if tc.setup != nil { + tc.setup(t) + } + defer func() { + if tc.clean != nil { + tc.clean(t) + } + }() + lis, err := netpoll.CreateListener("tcp", "localhost:0") + test.Assert(t, err == nil, err) + _, port, err := net.SplitHostPort(lis.Addr().String()) + test.Assert(t, err == nil, err) + cfg := &ServerConfig{} + var wg sync.WaitGroup + var onConnect netpoll.OnConnect = func(ctx context.Context, conn netpoll.Connection) context.Context { + rawSrv, err := newHTTP2Server(ctx, conn, cfg) + srv := rawSrv.(*http2Server) + test.Assert(t, err == nil, err) + srv.HandleStreams(func(stream *Stream) { + wg.Add(1) + go func() { + defer wg.Done() + md, ok := metadata.FromIncomingContext(stream.Context()) + test.Assert(t, ok) + vals := md.Get(testTypeKey) + test.Assert(t, len(vals) == 1, md) + testType := vals[0] + expectedErr := tc.errs.getServerExpectedErr(testType) + switch testType { + case errMiddleHeader: + errMiddleHeaderHandler(t, srv, stream, expectedErr) + case errDecodeHeader: + errDecodeHeaderHandler(t, srv, stream, expectedErr) + case errHTTP2Stream: + errHTTP2StreamHandler(t, srv, stream, expectedErr) + case errBizCanceled: + errBizCanceledHandler(t, srv, stream, expectedErr) + case errStreamTimeout: + errStreamTimeoutHandler(t, srv, stream, expectedErr) + case errClosedWithoutTrailer: + errClosedWithoutTrailerHandler(t, srv, stream, expectedErr) + case errRecvRstStream: + errRecvRstStreamHandler(t, srv, stream, expectedErr) + } + }() + }, func(ctx context.Context, s string) context.Context { + return ctx + }) + return nil + } + eventloop, err := netpoll.NewEventLoop(nil, + netpoll.WithOnConnect(onConnect), + netpoll.WithIdleTimeout(10*time.Second), + ) + test.Assert(t, err == nil, err) + go func() { + eventloop.Serve(lis) + }() + // create http2Client + conn, dErr := netpoll.NewDialer().DialTimeout("tcp", "localhost:"+port, time.Second) + test.Assert(t, dErr == nil, dErr) + cli, cErr := newHTTP2Client(context.Background(), conn.(netpoll.Connection), ConnectOptions{}, "", func(GoAwayReason) {}, func() {}) + test.Assert(t, cErr == nil, cErr) + defer func() { + wg.Wait() + cli.Close(status.Err(codes.Internal, "test")) + eventloop.Shutdown(context.Background()) + }() + callHdr := &CallHdr{ + Host: "host", + Method: "method", + } + buf := make([]byte, 1) + t.Run("Headers Frame appeared in the middle of the stream", func(t *testing.T) { + testType := errMiddleHeader + expectedErr := tc.errs.getClientExpectedErr(testType) + ctx := metadata.AppendToOutgoingContext(context.Background(), testTypeKey, testType) + stream, err := cli.NewStream(ctx, callHdr) + test.Assert(t, err == nil, err) + _, err = stream.Header() + test.Assert(t, err == nil, err) + _, recvErr := stream.Read(buf) + test.Assert(t, errors.Is(recvErr, expectedErr), recvErr) + }) + t.Run("Decode Headers Frame failed", func(t *testing.T) { + testType := errDecodeHeader + expectedErr := tc.errs.getClientExpectedErr(testType) + ctx := metadata.AppendToOutgoingContext(context.Background(), testTypeKey, errDecodeHeader) + stream, err := cli.NewStream(ctx, callHdr) + test.Assert(t, err == nil, err) + _, recvErr := stream.Read(buf) + test.Assert(t, errors.Is(recvErr, expectedErr), recvErr) + }) + t.Run("HTTP2Stream err when parsing frame", func(t *testing.T) { + testType := errMiddleHeader + expectedErr := tc.errs.getClientExpectedErr(testType) + ctx := metadata.AppendToOutgoingContext(context.Background(), testTypeKey, testType) + stream, err := cli.NewStream(ctx, callHdr) + test.Assert(t, err == nil, err) + _, recvErr := stream.Read(buf) + test.Assert(t, errors.Is(recvErr, expectedErr), recvErr) + }) + t.Run("Biz context canceled", func(t *testing.T) { + testType := errBizCanceled + expectedErr := tc.errs.getClientExpectedErr(testType) + ctx := metadata.AppendToOutgoingContext(context.Background(), testTypeKey, testType) + ctx, cancel := context.WithCancel(ctx) + stream, err := cli.NewStream(ctx, callHdr) + test.Assert(t, err == nil, err) + _, err = stream.Header() + test.Assert(t, err == nil, err) + cancel() + _, recvErr := stream.Read(buf) + test.Assert(t, errors.Is(recvErr, expectedErr), recvErr) + }) + t.Run("Timeout context canceled", func(t *testing.T) { + testType := errStreamTimeout + expectedErr := tc.errs.getClientExpectedErr(testType) + ctx := metadata.AppendToOutgoingContext(context.Background(), testTypeKey, testType) + ctx, cancel := context.WithTimeout(ctx, time.Millisecond*10) + defer cancel() + stream, err := cli.NewStream(ctx, callHdr) + test.Assert(t, err == nil, err) + _, err = stream.Header() + test.Assert(t, err == nil, err) + _, recvErr := stream.Read(buf) + test.Assert(t, errors.Is(recvErr, expectedErr), recvErr) + }) + t.Run("Stream closed without trailer frame", func(t *testing.T) { + testType := errMiddleHeader + expectedErr := tc.errs.getClientExpectedErr(testType) + ctx := metadata.AppendToOutgoingContext(context.Background(), testTypeKey, testType) + stream, err := cli.NewStream(ctx, callHdr) + test.Assert(t, err == nil, err) + _, err = stream.Header() + test.Assert(t, err == nil, err) + _, recvErr := stream.Read(buf) + if errors.Is(recvErr, io.EOF) { + recvErr = stream.Status().Err() + } + test.Assert(t, errors.Is(recvErr, expectedErr), recvErr) + }) + t.Run("Receive RstStream Frame", func(t *testing.T) { + testType := errRecvRstStream + expectedErr := tc.errs.getClientExpectedErr(testType) + ctx := metadata.AppendToOutgoingContext(context.Background(), testTypeKey, testType) + stream, err := cli.NewStream(ctx, callHdr) + test.Assert(t, err == nil, err) + _, err = stream.Header() + test.Assert(t, err == nil, err) + _, recvErr := stream.Read(buf) + if errors.Is(recvErr, io.EOF) { + recvErr = stream.Status().Err() + } + test.Assert(t, errors.Is(recvErr, expectedErr), recvErr) + }) + }) + } + }) +} + +func errMiddleHeaderHandler(t *testing.T, srv *http2Server, stream *Stream, expectedErr error) { + var err error + buf := make([]byte, 5) + err = stream.SendHeader(nil) + test.Assert(t, err == nil, err) + err = srv.controlBuf.put(&headerFrame{ + streamID: stream.id, + endStream: false, + }) + test.Assert(t, err == nil, err) + _, recvErr := stream.Read(buf) + test.Assert(t, errors.Is(recvErr, expectedErr), recvErr) +} + +func errDecodeHeaderHandler(t *testing.T, srv *http2Server, stream *Stream, expectedErr error) { + var err error + buf := make([]byte, 5) + err = srv.controlBuf.put(&headerFrame{ + streamID: stream.id, + }) + test.Assert(t, err == nil, err) + _, recvErr := stream.Read(buf) + test.Assert(t, errors.Is(recvErr, expectedErr), recvErr) +} + +func errHTTP2StreamHandler(t *testing.T, srv *http2Server, stream *Stream, expectedErr error) { + var err error + buf := make([]byte, 5) + err = srv.controlBuf.put(&headerFrame{ + streamID: stream.id, + // regular header field is previous to pseudo header field + hf: []hpack.HeaderField{ + {Name: "key", Value: "val"}, + {Name: ":status", Value: "200"}, + }, + }) + test.Assert(t, err == nil, err) + _, recvErr := stream.Read(buf) + test.Assert(t, errors.Is(recvErr, expectedErr), recvErr) +} + +func errBizCanceledHandler(t *testing.T, srv *http2Server, stream *Stream, expectedErr error) { + var err error + buf := make([]byte, 5) + err = stream.SendHeader(nil) + test.Assert(t, err == nil, err) + _, recvErr := stream.Read(buf) + test.Assert(t, errors.Is(recvErr, expectedErr), recvErr) +} + +func errStreamTimeoutHandler(t *testing.T, srv *http2Server, stream *Stream, expectedErr error) { + var err error + buf := make([]byte, 5) + err = stream.SendHeader(nil) + test.Assert(t, err == nil, err) + _, recvErr := stream.Read(buf) + test.Assert(t, errors.Is(recvErr, expectedErr), recvErr) +} + +func errClosedWithoutTrailerHandler(t *testing.T, srv *http2Server, stream *Stream, expectedErr error) { + var err error + err = stream.SendHeader(nil) + test.Assert(t, err == nil, err) + err = srv.controlBuf.put(&dataFrame{ + streamID: stream.id, + endStream: true, + }) + test.Assert(t, err == nil, err) +} + +func errRecvRstStreamHandler(t *testing.T, srv *http2Server, stream *Stream, expectedErr error) { + err := stream.SendHeader(nil) + test.Assert(t, err == nil, err) + err = srv.controlBuf.put(&cleanupStream{ + streamID: stream.id, + rst: true, + rstCode: getRstCode(status.New(codes.Unavailable, "test").WithMappingErr(kerrors.ErrGracefulShutdown).Err()), + onWrite: func() {}, + }) + test.Assert(t, err == nil, err) +} diff --git a/pkg/remote/trans/nphttp2/grpc/http2_client.go b/pkg/remote/trans/nphttp2/grpc/http2_client.go index 58b2eaba01..f004b98559 100644 --- a/pkg/remote/trans/nphttp2/grpc/http2_client.go +++ b/pkg/remote/trans/nphttp2/grpc/http2_client.go @@ -39,6 +39,7 @@ import ( "github.com/cloudwego/kitex/pkg/gofunc" "github.com/cloudwego/kitex/pkg/klog" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/codes" + gerrors "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/errors" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/grpc/grpcframe" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/grpc/syscall" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/metadata" @@ -219,12 +220,14 @@ func newHTTP2Client(ctx context.Context, conn net.Conn, opts ConnectOptions, // Send connection preface to server. n, err := t.conn.Write(ClientPreface) if err != nil { - err = connectionErrorf(true, err, "transport: failed to write client preface: %v", err) + err = status.Newf(codes.Unavailable, "transport: failed to write client preface: %v", err). + WithMappingErr(gerrors.ErrEstablishConnection).Err() t.Close(err) return nil, err } if n != ClientPrefaceLen { - err = connectionErrorf(true, err, "transport: preface mismatch, wrote %d bytes; want %d", n, ClientPrefaceLen) + err = status.Newf(codes.Unavailable, "transport: preface mismatch, wrote %d bytes; want %d", n, ClientPrefaceLen). + WithMappingErr(gerrors.ErrEstablishConnection).Err() t.Close(err) return nil, err } @@ -237,7 +240,8 @@ func newHTTP2Client(ctx context.Context, conn net.Conn, opts ConnectOptions, } err = t.framer.WriteSettings(ss...) if err != nil { - err = connectionErrorf(true, err, "transport: failed to write initial settings frame: %v", err) + err = status.Newf(codes.Unavailable, "transport: failed to write initial settings frame: %v", err). + WithMappingErr(gerrors.ErrEstablishConnection).Err() t.Close(err) return nil, err } @@ -245,7 +249,8 @@ func newHTTP2Client(ctx context.Context, conn net.Conn, opts ConnectOptions, // Adjust the connection flow control window if needed. if delta := uint32(icwz - defaultWindowSize); delta > 0 { if err := t.framer.WriteWindowUpdate(0, delta); err != nil { - err = connectionErrorf(true, err, "transport: failed to write window update: %v", err) + err = status.Newf(codes.Unavailable, "transport: failed to write window update: %v", err). + WithMappingErr(gerrors.ErrEstablishConnection).Err() t.Close(err) return nil, err } @@ -431,9 +436,12 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea if state := t.state; state != reachable { t.mu.Unlock() // Do a quick cleanup. - err := error(errStreamDrain) + err := errStatusStreamDrain if state == closing { - err = ErrConnClosing + err = errStatusConnClosing + // make sure the error exposed to users is *status.Error + cleanup(err) + return ErrConnClosing } cleanup(err) return err @@ -516,7 +524,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea case <-s.ctx.Done(): return nil, ContextErr(s.ctx.Err()) case <-t.goAway: - return nil, errStreamDrain + return nil, errStatusStreamDrain case <-t.ctx.Done(): return nil, ErrConnClosing } @@ -534,6 +542,7 @@ func (t *http2Client) CloseStream(s *Stream, err error) { if err != nil { rst = true rstCode = http2.ErrCodeCancel + klog.CtxInfof(s.ctx, "KITEX: stream closed by ctx canceled, err: %v"+sendRSTStreamFrameSuffix, err) } t.closeStream(s, err, rst, rstCode, status.Convert(err), nil, false) } @@ -557,11 +566,24 @@ func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2. // This will unblock reads eventually. s.write(recvMsg{err: err}) } + + // store closeStreamErr + storeErr := err + if err == io.EOF { + storeErr = st.Err() + } + if storeErr != nil { + s.closeStreamErr.Store(storeErr) + } + // If headerChan isn't closed, then close it. if atomic.CompareAndSwapUint32(&s.headerChanClosed, 0, 1) { s.noHeaders = true close(s.headerChan) } + if rst && isCustomRstCodeEnabled() { + rstCode = getRstCode(err) + } cleanup := &cleanupStream{ streamID: s.id, onWrite: func() { @@ -617,7 +639,7 @@ func (t *http2Client) Close(err error) error { t.kpDormancyCond.Signal() } t.mu.Unlock() - t.controlBuf.finish() + t.controlBuf.finish(err) t.cancel() cErr := t.conn.Close() @@ -644,7 +666,8 @@ func (t *http2Client) GracefulClose() { active := len(t.activeStreams) t.mu.Unlock() if active == 0 { - t.Close(connectionErrorf(true, nil, "no active streams left to process while draining")) + t.Close(status.New(codes.Unavailable, "no active streams left to process while draining"). + WithMappingErr(gerrors.ErrNoActiveStream).Err()) return } t.controlBuf.put(&incomingGoAway{}) @@ -656,10 +679,10 @@ func (t *http2Client) Write(s *Stream, hdr, data []byte, opts *Options) error { if opts.Last { // If it's the last message, update stream state. if !s.compareAndSwapState(streamActive, streamWriteDone) { - return errStreamDone + return s.getCloseStreamErr() } } else if s.getState() != streamActive { - return errStreamDone + return s.getCloseStreamErr() } df := newDataFrame() df.streamID = s.id @@ -670,7 +693,7 @@ func (t *http2Client) Write(s *Stream, hdr, data []byte, opts *Options) error { df.originD = df.d if hdr != nil || data != nil { // If it's not an empty data frame, check quota. if err := s.wq.get(int32(len(hdr) + len(data))); err != nil { - return err + return s.getCloseStreamErr() } } return t.controlBuf.put(df) @@ -766,7 +789,10 @@ func (t *http2Client) handleData(f *grpcframe.DataFrame) { } if size > 0 { if err := s.fc.onData(size); err != nil { - t.closeStream(s, io.EOF, true, http2.ErrCodeFlowControl, status.New(codes.Internal, err.Error()), nil, false) + klog.CtxErrorf(s.ctx, "KITEX: http2Client.handleData inflow control err: %v, code: %d"+sendRSTStreamFrameSuffix, err, http2.ErrCodeFlowControl) + st := status.New(codes.Internal, err.Error()). + WithMappingErr(gerrors.ErrStreamFlowControl) + t.closeStream(s, io.EOF, true, http2.ErrCodeFlowControl, st, nil, false) return } if f.Header().Flags.Has(http2.FlagDataPadded) { @@ -787,7 +813,9 @@ func (t *http2Client) handleData(f *grpcframe.DataFrame) { // The server has closed the stream without sending trailers. Record that // the read direction is closed, and set the status appropriately. if f.FrameHeader.Flags.Has(http2.FlagDataEndStream) { - t.closeStream(s, io.EOF, false, http2.ErrCodeNo, status.New(codes.Internal, "server closed the stream without sending trailers"), nil, true) + st := status.New(codes.Internal, "server closed the stream without sending trailers"). + WithMappingErr(gerrors.ErrClosedWithoutTrailer) + t.closeStream(s, io.EOF, false, http2.ErrCodeNo, st, nil, true) } } @@ -800,19 +828,12 @@ func (t *http2Client) handleRSTStream(f *http2.RSTStreamFrame) { // The stream was unprocessed by the server. atomic.StoreUint32(&s.unprocessed, 1) } - statusCode, ok := http2ErrConvTab[f.ErrCode] - if !ok { - klog.Warnf("transport: http2Client.handleRSTStream found no mapped gRPC status for the received nhttp2 error %v", f.ErrCode) - statusCode = codes.Unknown - } - if statusCode == codes.Canceled { - if d, ok := s.ctx.Deadline(); ok && !d.After(time.Now()) { - // Our deadline was already exceeded, and that was likely the cause - // of this cancelation. Alter the status code accordingly. - statusCode = codes.DeadlineExceeded - } - } - t.closeStream(s, io.EOF, false, http2.ErrCodeNo, status.Newf(statusCode, "stream terminated by RST_STREAM with error code: %v", f.ErrCode), nil, false) + + mappingErr, stCode := getMappingErrAndStatusCode(s.ctx, f.ErrCode) + + st := status.Newf(stCode, "stream terminated by RST_STREAM with error code: %v", f.ErrCode). + WithMappingErr(mappingErr) + t.closeStream(s, io.EOF, false, http2.ErrCodeNo, st, nil, false) } func (t *http2Client) handleSettings(f *grpcframe.SettingsFrame, isFirst bool) { @@ -889,7 +910,9 @@ func (t *http2Client) handleGoAway(f *grpcframe.GoAwayFrame) { id := f.LastStreamID if id > 0 && id%2 != 1 { t.mu.Unlock() - t.Close(connectionErrorf(true, nil, "received goaway with non-zero even-numbered numbered stream id: %v", id)) + st := status.Newf(codes.Unavailable, "received goaway with non-zero even-numbered numbered stream id: %v", id). + WithMappingErr(gerrors.ErrHandleGoAway) + t.Close(st.Err()) return } // A client can receive multiple GoAways from the server (see @@ -907,7 +930,9 @@ func (t *http2Client) handleGoAway(f *grpcframe.GoAwayFrame) { // If there are multiple GoAways the first one should always have an ID greater than the following ones. if id > t.prevGoAwayID { t.mu.Unlock() - t.Close(connectionErrorf(true, nil, "received goaway with stream id: %v, which exceeds stream id of previous goaway: %v", id, t.prevGoAwayID)) + st := status.Newf(codes.Unavailable, "received goaway with stream id: %v, which exceeds stream id of previous goaway: %v", id, t.prevGoAwayID). + WithMappingErr(gerrors.ErrHandleGoAway) + t.Close(st.Err()) return } default: @@ -932,14 +957,18 @@ func (t *http2Client) handleGoAway(f *grpcframe.GoAwayFrame) { if streamID > id && streamID <= upperLimit { // The stream was unprocessed by the server. atomic.StoreUint32(&stream.unprocessed, 1) - t.closeStream(stream, errStreamDrain, false, http2.ErrCodeNo, statusGoAway, nil, false) + st := status.New(codes.Unavailable, "the stream is rejected because server is draining the connection"). + WithMappingErr(gerrors.ErrStreamDrain) + t.closeStream(stream, st.Err(), false, http2.ErrCodeNo, st, nil, false) } } t.prevGoAwayID = id active := len(t.activeStreams) t.mu.Unlock() if active == 0 { - t.Close(connectionErrorf(true, nil, "received goaway and there are no active streams")) + st := status.New(codes.Unavailable, "received goaway and there are no active streams"). + WithMappingErr(gerrors.ErrNoActiveStream) + t.Close(st.Err()) } } @@ -982,7 +1011,9 @@ func (t *http2Client) operateHeaders(frame *grpcframe.MetaHeadersFrame) { if !initialHeader && !endStream { // As specified by gRPC over HTTP2, a HEADERS frame (and associated CONTINUATION frames) can only appear at the start or end of a stream. Therefore, second HEADERS frame must have EOS bit set. - st := status.New(codes.Internal, "a HEADERS frame cannot appear in the middle of a stream") + st := status.New(codes.Internal, "a HEADERS frame cannot appear in the middle of a stream"). + WithMappingErr(gerrors.ErrMiddleHeader) + klog.CtxErrorf(s.ctx, "KITEX: http2Client.operateHeaders received HEADERS frame in the middle of a stream"+sendRSTStreamFrameSuffix) t.closeStream(s, st.Err(), true, http2.ErrCodeProtocol, st, nil, false) return } @@ -990,8 +1021,10 @@ func (t *http2Client) operateHeaders(frame *grpcframe.MetaHeadersFrame) { state := &decodeState{} // Initialize isGRPC value to be !initialHeader, since if a gRPC Response-Headers has already been received, then it means that the peer is speaking gRPC and we are in gRPC mode. state.data.isGRPC = !initialHeader - if err := state.decodeHeader(frame); err != nil { - t.closeStream(s, err, true, http2.ErrCodeProtocol, status.Convert(err), nil, endStream) + if st := state.decodeHeader(frame); st != nil { + klog.CtxErrorf(s.ctx, "KITEX: http2Client.operateHeaders decode HEADERS frame failed, err: %v, code: %d"+sendRSTStreamFrameSuffix, st.Err(), http2.ErrCodeProtocol) + st = st.WithMappingErr(gerrors.ErrDecodeHeader) + t.closeStream(s, st.Err(), true, http2.ErrCodeProtocol, st, nil, endStream) return } @@ -1034,8 +1067,9 @@ func (t *http2Client) reader() { // Check the validity of server preface. frame, err := t.framer.ReadFrame() if err != nil { - err = connectionErrorf(true, err, "error reading from server, remoteAddress=%s, error=%v", t.conn.RemoteAddr(), err) - t.Close(err) // this kicks off resetTransport, so must be last before return + st := status.Newf(codes.Unavailable, "error reading from server, remoteAddress=%s, error=%v", t.conn.RemoteAddr(), err). + WithMappingErr(gerrors.ErrEstablishConnection) + t.Close(st.Err()) // this kicks off resetTransport, so must be last before return return } t.conn.SetReadDeadline(time.Time{}) // reset deadline once we get the settings frame (we didn't time out, yay!) @@ -1044,8 +1078,9 @@ func (t *http2Client) reader() { } sf, ok := frame.(*grpcframe.SettingsFrame) if !ok { - err = connectionErrorf(true, err, "first frame received is not a setting frame") - t.Close(err) // this kicks off resetTransport, so must be last before return + st := status.New(codes.Unavailable, "first frame received is not a setting frame"). + WithMappingErr(gerrors.ErrEstablishConnection) + t.Close(st.Err()) // this kicks off resetTransport, so must be last before return return } t.handleSettings(sf, true) @@ -1073,13 +1108,16 @@ func (t *http2Client) reader() { if err != nil { msg = err.Error() } - t.closeStream(s, status.New(code, msg).Err(), true, http2.ErrCodeProtocol, status.New(code, msg), nil, false) + klog.CtxErrorf(s.ctx, "KITEX: http2Client.reader encountered http2.StreamError: %v, code: %d"+sendRSTStreamFrameSuffix, se, http2.ErrCodeProtocol) + st := status.New(code, msg).WithMappingErr(gerrors.ErrHTTP2Stream) + t.closeStream(s, st.Err(), true, http2.ErrCodeProtocol, st, nil, false) } continue } else { // Transport error. - err = connectionErrorf(true, err, "error reading from server, remoteAddress=%s, error=%v", t.conn.RemoteAddr(), err) - t.Close(err) + st := status.Newf(codes.Unavailable, "error reading from server, remoteAddress=%s, error=%v", t.conn.RemoteAddr(), err). + WithMappingErr(gerrors.ErrHTTP2Connection) + t.Close(st.Err()) return } } @@ -1137,7 +1175,9 @@ func (t *http2Client) keepalive() { continue } if outstandingPing && timeoutLeft <= 0 { - t.Close(connectionErrorf(true, nil, "keepalive ping failed to receive ACK within timeout")) + st := status.New(codes.Unavailable, "keepalive ping failed to receive ACK within timeout"). + WithMappingErr(gerrors.ErrKeepAlive) + t.Close(st.Err()) return } t.mu.Lock() diff --git a/pkg/remote/trans/nphttp2/grpc/http2_server.go b/pkg/remote/trans/nphttp2/grpc/http2_server.go index c2b84efe13..e7ca91c652 100644 --- a/pkg/remote/trans/nphttp2/grpc/http2_server.go +++ b/pkg/remote/trans/nphttp2/grpc/http2_server.go @@ -34,7 +34,9 @@ import ( "sync/atomic" "time" + "github.com/cloudwego/kitex/pkg/kerrors" "github.com/cloudwego/kitex/pkg/remote/codec/protobuf/encoding" + gerrors "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/errors" "github.com/cloudwego/netpoll" "golang.org/x/net/http2" @@ -53,20 +55,24 @@ import ( var ( // ErrIllegalHeaderWrite indicates that setting header is illegal because of // the stream's state. - ErrIllegalHeaderWrite = errors.New("transport: the stream is done or WriteHeader was already called") + ErrIllegalHeaderWrite = errors.New("transport: WriteHeader was already called") + errStatusIllegalHeaderWrite = status.New(codes.Internal, ErrIllegalHeaderWrite.Error()+triggeredByHandlerSideSuffix). + WithMappingErr(gerrors.ErrIllegalHeaderWrite).Err() // ErrHeaderListSizeLimitViolation indicates that the header list size is larger // than the limit set by peer. - ErrHeaderListSizeLimitViolation = errors.New("transport: trying to send header list size larger than the limit set by peer") + ErrHeaderListSizeLimitViolation = errors.New("transport: trying to send header list size larger than the limit set by peer") + errStatusHeaderListSizeLimitViolation = status.New(codes.Internal, ErrHeaderListSizeLimitViolation.Error()+triggeredByHandlerSideSuffix). + WithMappingErr(kerrors.ErrMetaSizeExceeded).Err() // errors used for cancelling stream. // the code should be codes.Canceled coz it's NOT returned from remote - errConnectionEOF = status.New(codes.Canceled, "transport: connection EOF").Err() - errStreamClosing = status.New(codes.Canceled, "transport: stream is closing").Err() - errMaxStreamsExceeded = status.New(codes.Canceled, "transport: max streams exceeded").Err() - errNotReachable = status.New(codes.Canceled, "transport: server not reachable").Err() - errMaxAgeClosing = status.New(codes.Canceled, "transport: closing server transport due to maximum connection age").Err() - errIdleClosing = status.New(codes.Canceled, "transport: closing server transport due to idleness").Err() + errStatusConnectionEOF = status.New(codes.Canceled, "transport: connection EOF"+triggeredByRemoteServiceSuffix). + WithMappingErr(gerrors.ErrHTTP2Connection).Err() + errStatusMaxStreamsExceeded = status.New(codes.Canceled, "transport: max streams exceeded"+triggeredByRemoteServiceSuffix). + WithMappingErr(gerrors.ErrMaxStreamExceeded).Err() + errStatusNotReachable = status.New(codes.Canceled, "transport: server not reachable"+triggeredByRemoteServiceSuffix). + WithMappingErr(gerrors.ErrNotReachable).Err() ) func init() { @@ -177,14 +183,18 @@ func newHTTP2Server(ctx context.Context, conn net.Conn, config *ServerConfig) (_ } if err := framer.WriteSettings(isettings...); err != nil { - return nil, connectionErrorf(false, err, "transport: %v", err) + st := status.Newf(codes.Unavailable, "transport: server failed to write initial settings frame: %v", err). + WithMappingErr(gerrors.ErrEstablishConnection) + return nil, st.Err() } // Adjust the connection flow control window if needed. if icwz > defaultWindowSize { if delta := icwz - defaultWindowSize; delta > 0 { if err := framer.WriteWindowUpdate(0, delta); err != nil { - return nil, connectionErrorf(false, err, "transport: %v", err) + st := status.Newf(codes.Unavailable, "transport: server failed to write window update frame: %v", err). + WithMappingErr(gerrors.ErrEstablishConnection) + return nil, st.Err() } } } @@ -247,32 +257,42 @@ func newHTTP2Server(ctx context.Context, conn net.Conn, config *ServerConfig) (_ // Check the validity of client preface. preface := make([]byte, len(ClientPreface)) - if _, err := io.ReadFull(t.conn, preface); err != nil { + if _, rErr := io.ReadFull(t.conn, preface); rErr != nil { // In deployments where a gRPC server runs behind a cloud load balancer // which performs regular TCP level health checks, the connection is // closed immediately by the latter. Returning io.EOF here allows the // grpc server implementation to recognize this scenario and suppress // logging to reduce spam. - if err == io.EOF { + if rErr == io.EOF { return nil, io.EOF } - return nil, connectionErrorf(false, err, "transport: http2Server.HandleStreams failed to receive the preface from client: %v", err) + err = status.Newf(codes.Unavailable, "transport: server failed to receive the preface from client: %v", rErr). + WithMappingErr(gerrors.ErrEstablishConnection).Err() + return nil, err } if !bytes.Equal(preface, ClientPreface) { - return nil, connectionErrorf(false, nil, "transport: http2Server.HandleStreams received bogus greeting from client: %q", preface) + err = status.Newf(codes.Unavailable, "transport: server received bogus greeting from client: %q", preface). + WithMappingErr(gerrors.ErrEstablishConnection).Err() + return nil, err } - frame, err := t.framer.ReadFrame() - if err == io.EOF || err == io.ErrUnexpectedEOF { + frame, rErr := t.framer.ReadFrame() + if rErr == io.EOF || rErr == io.ErrUnexpectedEOF { + err = status.Newf(codes.Unavailable, "transport: connection EOF: %v", err). + WithMappingErr(gerrors.ErrEstablishConnection).Err() return nil, err } - if err != nil { - return nil, connectionErrorf(false, err, "transport: http2Server.HandleStreams failed to read initial settings frame: %v", err) + if rErr != nil { + err = status.Newf(codes.Unavailable, "transport: server failed to read initial settings frame: %v", err). + WithMappingErr(gerrors.ErrEstablishConnection).Err() + return nil, err } atomic.StoreInt64(&t.lastRead, time.Now().UnixNano()) sf, ok := frame.(*grpcframe.SettingsFrame) if !ok { - return nil, connectionErrorf(false, nil, "transport: http2Server.HandleStreams saw invalid preface type %T from client", frame) + err = status.Newf(codes.Unavailable, "transport: server received invalid frame type %T from client", frame). + WithMappingErr(gerrors.ErrEstablishConnection).Err() + return nil, err } t.handleSettings(sf) @@ -292,20 +312,21 @@ func newHTTP2Server(ctx context.Context, conn net.Conn, config *ServerConfig) (_ // operateHeaders takes action on the decoded headers. Returns an error if fatal // error encountered and transport needs to close, otherwise returns nil. +// users will only be able to perceive the stream if the input handle is executed. func (t *http2Server) operateHeaders(frame *grpcframe.MetaHeadersFrame, handle func(*Stream), traceCtx func(context.Context, string) context.Context) error { streamID := frame.Header().StreamID state := &decodeState{ serverSide: true, } - if err := state.decodeHeader(frame); err != nil { - if se, ok := status.FromError(err); ok { - t.controlBuf.put(&cleanupStream{ - streamID: streamID, - rst: true, - rstCode: statusCodeConvTab[se.Code()], - onWrite: func() {}, - }) - } + if st := state.decodeHeader(frame); st != nil { + rstCode := statusCodeConvTab[st.Code()] + klog.CtxErrorf(t.ctx, "KITEX: http2Server.operateHeaders failed to decode header frame, err=%v, code: %d"+sendRSTStreamFrameSuffix, st.Err(), rstCode) + t.controlBuf.put(&cleanupStream{ + streamID: streamID, + rst: true, + rstCode: rstCode, + onWrite: func() {}, + }) return nil } @@ -330,7 +351,7 @@ func (t *http2Server) operateHeaders(frame *grpcframe.MetaHeadersFrame, handle f } else { s.ctx, cancel = context.WithCancel(t.ctx) } - s.ctx, s.cancel = newContextWithCancelReason(s.ctx, cancel) + s.ctx, s.cancelFunc = newContextWithCancelReason(s.ctx, cancel) // Attach the received metadata to the context. if len(state.data.mdata) > 0 { s.ctx = metadata.NewIncomingContext(s.ctx, state.data.mdata) @@ -339,18 +360,19 @@ func (t *http2Server) operateHeaders(frame *grpcframe.MetaHeadersFrame, handle f t.mu.Lock() if t.state != reachable { t.mu.Unlock() - s.cancel(errNotReachable) + s.cancel(errStatusNotReachable) return nil } if uint32(len(t.activeStreams)) >= t.maxStreams { t.mu.Unlock() + klog.CtxErrorf(t.ctx, "KITEX: http2Server.operateHeaders failed to create stream, err=%v, code: %d"+sendRSTStreamFrameSuffix, errStatusMaxStreamsExceeded, http2.ErrCodeRefusedStream) t.controlBuf.put(&cleanupStream{ streamID: streamID, rst: true, rstCode: http2.ErrCodeRefusedStream, onWrite: func() {}, }) - s.cancel(errMaxStreamsExceeded) + s.cancel(errStatusMaxStreamsExceeded) return nil } if streamID%2 != 1 || streamID <= t.maxStreamID { @@ -405,11 +427,12 @@ func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context. s := t.activeStreams[se.StreamID] t.mu.Unlock() if s != nil { - // it will be codes.Internal error for GRPC - // TODO: map http2.StreamError to status.Error? - s.cancel(err) - t.closeStream(s, true, se.Code, false) + klog.CtxErrorf(s.ctx, "KITEX: http2Server.HandleStreams encountered http2.StreamError, err=%v, code: %d"+sendRSTStreamFrameSuffix, err, se.Code) + stErr := status.Newf(codes.Canceled, "transport: ReadFrame encountered http2.StreamError: %v [triggered by %s]", err, s.getSourceService()). + WithMappingErr(gerrors.ErrHTTP2Stream).Err() + t.closeStream(s, stErr, true, se.Code, false) } else { + klog.CtxErrorf(t.ctx, "KITEX: http2Server.HandleStreams failed to ReadFrame, err=%v, code: %d"+sendRSTStreamFrameSuffix, err, se.Code) t.controlBuf.put(&cleanupStream{ streamID: se.StreamID, rst: true, @@ -420,18 +443,22 @@ func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context. continue } if err == io.EOF || err == io.ErrUnexpectedEOF || errors.Is(err, netpoll.ErrEOF) { - t.closeWithErr(errConnectionEOF) + t.closeWithErr(errStatusConnectionEOF) return } klog.CtxWarnf(t.ctx, "transport: http2Server.HandleStreams failed to read frame: %v", err) - t.closeWithErr(err) + stErr := status.Newf(codes.Canceled, "transport: ReadFrame encountered err: %v"+triggeredByRemoteServiceSuffix, err). + WithMappingErr(gerrors.ErrHTTP2Connection).Err() + t.closeWithErr(stErr) return } switch frame := frame.(type) { case *grpcframe.MetaHeadersFrame: if err := t.operateHeaders(frame, handle, traceCtx); err != nil { klog.CtxErrorf(t.ctx, "transport: http2Server.HandleStreams fatal err: %v", err) - t.closeWithErr(err) + stErr := status.New(codes.Canceled, err.Error()). + WithMappingErr(gerrors.ErrOperateHeaders).Err() + t.closeWithErr(stErr) break } case *grpcframe.DataFrame: @@ -551,7 +578,10 @@ func (t *http2Server) handleData(f *grpcframe.DataFrame) { } if size > 0 { if err := s.fc.onData(size); err != nil { - t.closeStream(s, true, http2.ErrCodeFlowControl, false) + klog.CtxErrorf(s.ctx, "KITEX: http2Server.handleData inflow control err: %v, code: %d"+sendRSTStreamFrameSuffix, err, http2.ErrCodeFlowControl) + stErr := status.Newf(codes.Canceled, "transport: inflow control err: %v [triggered by %s]", err, s.getSourceService()). + WithMappingErr(gerrors.ErrStreamFlowControl).Err() + t.closeStream(s, stErr, true, http2.ErrCodeFlowControl, false) return } if f.Header().Flags.Has(http2.FlagDataPadded) { @@ -579,7 +609,11 @@ func (t *http2Server) handleData(f *grpcframe.DataFrame) { func (t *http2Server) handleRSTStream(f *http2.RSTStreamFrame) { // If the stream is not deleted from the transport's active streams map, then do a regular close stream. if s, ok := t.getStream(f); ok { - t.closeStream(s, false, 0, false) + mappingErr, stCode := getMappingErrAndStatusCode(s.ctx, f.ErrCode) + stErr := status.Newf(stCode, "transport: RSTStream Frame received with error code: %d [triggered by %s]", f.ErrCode, s.getSourceService()). + WithMappingErr(mappingErr).Err() + klog.CtxInfof(s.ctx, "transport: http2Server.handleRSTStream received RSTStream Frame with error code: %v", f.ErrCode) + t.closeStream(s, stErr, false, 0, false) return } // If the stream is already deleted from the active streams map, then put a cleanupStream item into controlbuf to delete the stream from loopy writer's established streams map. @@ -589,6 +623,7 @@ func (t *http2Server) handleRSTStream(f *http2.RSTStreamFrame) { rstCode: 0, onWrite: func() {}, }) + // since we do not need to send RstStream Frame, do not add log here } func (t *http2Server) handleSettings(f *grpcframe.SettingsFrame) { @@ -711,8 +746,11 @@ func (t *http2Server) checkForHeaderListSize(it interface{}) bool { // WriteHeader sends the header metadata md back to the client. func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { - if s.updateHeaderSent() || s.getState() == streamDone { - return ErrIllegalHeaderWrite + if s.updateHeaderSent() { + return errStatusIllegalHeaderWrite + } + if s.getState() == streamDone { + return ContextErr(s.ctx.Err()) } s.hdrMu.Lock() if md.Len() > 0 { @@ -756,8 +794,10 @@ func (t *http2Server) writeHeaderLocked(s *Stream) error { if err != nil { return err } - t.closeStream(s, true, http2.ErrCodeInternal, false) - return ErrHeaderListSizeLimitViolation + klog.CtxErrorf(s.ctx, "KITEX: http2Server.writeHeaderLocked checkForHeaderListSize failed, code: %d"+sendRSTStreamFrameSuffix, http2.ErrCodeInternal) + stErr := errStatusHeaderListSizeLimitViolation + t.closeStream(s, stErr, true, http2.ErrCodeInternal, false) + return stErr } return nil } @@ -819,8 +859,9 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error { if err != nil { return err } - t.closeStream(s, true, http2.ErrCodeInternal, false) - return ErrHeaderListSizeLimitViolation + stErr := errStatusHeaderListSizeLimitViolation + t.closeStream(s, stErr, true, http2.ErrCodeInternal, false) + return stErr } // Send a RST_STREAM after the trailers if the client has not already half-closed. rst := s.getState() == streamActive @@ -838,13 +879,6 @@ func (t *http2Server) Write(s *Stream, hdr, data []byte, opts *Options) error { } else { // Writing headers checks for this condition. if s.getState() == streamDone { - // TODO(mmukhi, dfawley): Should the server write also return io.EOF? - s.cancel(errStreamClosing) - select { - case <-t.done: - return ErrConnClosing - default: - } return ContextErr(s.ctx.Err()) } } @@ -856,11 +890,6 @@ func (t *http2Server) Write(s *Stream, hdr, data []byte, opts *Options) error { df.originD = df.d df.resetPingStrikes = &t.resetPingStrikes if err := s.wq.get(int32(len(hdr) + len(data))); err != nil { - select { - case <-t.done: - return ErrConnClosing - default: - } return ContextErr(s.ctx.Err()) } return t.controlBuf.put(df) @@ -921,7 +950,9 @@ func (t *http2Server) keepalive() { case <-ageTimer.C: // Close the connection after grace period. klog.Infof("transport: closing server transport due to maximum connection age.") - t.closeWithErr(errMaxAgeClosing) + stErr := status.New(codes.Canceled, "transport: closing server transport due to maximum connection age"+triggeredByRemoteServiceSuffix). + WithMappingErr(gerrors.ErrKeepAlive).Err() + t.closeWithErr(stErr) case <-t.done: } return @@ -938,7 +969,9 @@ func (t *http2Server) keepalive() { } if outstandingPing && kpTimeoutLeft <= 0 { klog.Infof("transport: closing server transport due to idleness.") - t.closeWithErr(errIdleClosing) + stErr := status.New(codes.Canceled, "transport: closing server transport due to idleness"+triggeredByRemoteServiceSuffix). + WithMappingErr(gerrors.ErrKeepAlive).Err() + t.closeWithErr(stErr) return } if !outstandingPing { @@ -976,7 +1009,7 @@ func (t *http2Server) closeWithErr(reason error) error { streams := t.activeStreams t.activeStreams = nil t.mu.Unlock() - t.controlBuf.finish() + t.controlBuf.finish(reason) close(t.done) err := t.conn.Close() @@ -990,11 +1023,6 @@ func (t *http2Server) closeWithErr(reason error) error { // deleteStream deletes the stream s from transport's active streams. func (t *http2Server) deleteStream(s *Stream, eosReceived bool) { - // In case stream sending and receiving are invoked in separate - // goroutines (e.g., bi-directional streaming), cancel needs to be - // called to interrupt the potential blocking on other goroutines. - s.cancel(nil) // more details about the reason? - t.mu.Lock() if _, ok := t.activeStreams[s.id]; ok { delete(t.activeStreams, s.id) @@ -1012,7 +1040,11 @@ func (t *http2Server) finishStream(s *Stream, rst bool, rstCode http2.ErrCode, h // If the stream was already done, return. return } + s.cancel(nil) + if isCustomRstCodeEnabled() { + rstCode = getRstCode(nil) + } hdr.cleanup = &cleanupStream{ streamID: s.id, rst: rst, @@ -1025,10 +1057,17 @@ func (t *http2Server) finishStream(s *Stream, rst bool, rstCode http2.ErrCode, h } // closeStream clears the footprint of a stream when the stream is not needed any more. -func (t *http2Server) closeStream(s *Stream, rst bool, rstCode http2.ErrCode, eosReceived bool) { +func (t *http2Server) closeStream(s *Stream, err error, rst bool, rstCode http2.ErrCode, eosReceived bool) { + // In case stream sending and receiving are invoked in separate + // goroutines (e.g., bi-directional streaming), cancel needs to be + // called to interrupt the potential blocking on other goroutines. + s.cancel(err) s.swapState(streamDone) t.deleteStream(s, eosReceived) + if rst && isCustomRstCodeEnabled() { + rstCode = getRstCode(err) + } t.controlBuf.put(&cleanupStream{ streamID: s.id, rst: rst, diff --git a/pkg/remote/trans/nphttp2/grpc/http_util.go b/pkg/remote/trans/nphttp2/grpc/http_util.go index ecb6bb7ec1..aa12e826ab 100644 --- a/pkg/remote/trans/nphttp2/grpc/http_util.go +++ b/pkg/remote/trans/nphttp2/grpc/http_util.go @@ -22,8 +22,11 @@ package grpc import ( "bytes" + "context" "encoding/base64" + "errors" "fmt" + "io" "math" "net/http" "strconv" @@ -39,6 +42,7 @@ import ( "github.com/cloudwego/kitex/pkg/kerrors" "github.com/cloudwego/kitex/pkg/klog" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/codes" + gerrors "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/errors" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/grpc/grpcframe" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/status" "github.com/cloudwego/kitex/pkg/utils" @@ -141,8 +145,8 @@ type parsedHeaderData struct { // Otherwise (i.e. a content-type string starts without "application/grpc", or does not exist), we // are in HTTP fallback mode, and should handle error specific to HTTP. isGRPC bool - grpcErr error - httpErr error + grpcStatus *status.Status + httpErrStatus *status.Status contentTypeErr string } @@ -289,11 +293,11 @@ func decodeMetadataHeader(k, v string) (string, error) { return v, nil } -func (d *decodeState) decodeHeader(frame *grpcframe.MetaHeadersFrame) error { +func (d *decodeState) decodeHeader(frame *grpcframe.MetaHeadersFrame) *status.Status { // frame.Truncated is set to true when framer detects that the current header // list size hits MaxHeaderListSize limit. if frame.Truncated { - return status.New(codes.Internal, "peer header list size exceeded limit").Err() + return status.New(codes.Internal, "peer header list size exceeded limit") } for _, hf := range frame.Fields { @@ -301,8 +305,8 @@ func (d *decodeState) decodeHeader(frame *grpcframe.MetaHeadersFrame) error { } if d.data.isGRPC { - if d.data.grpcErr != nil { - return d.data.grpcErr + if d.data.grpcStatus != nil { + return d.data.grpcStatus } if d.serverSide { return nil @@ -321,8 +325,8 @@ func (d *decodeState) decodeHeader(frame *grpcframe.MetaHeadersFrame) error { } // HTTP fallback mode - if d.data.httpErr != nil { - return d.data.httpErr + if d.data.httpErrStatus != nil { + return d.data.httpErrStatus } var ( @@ -337,7 +341,7 @@ func (d *decodeState) decodeHeader(frame *grpcframe.MetaHeadersFrame) error { } } - return status.New(code, d.constructHTTPErrMsg()).Err() + return status.New(code, d.constructHTTPErrMsg()) } // constructErrMsg constructs error message to be returned in HTTP fallback mode. @@ -389,7 +393,7 @@ func (d *decodeState) processHeaderField(f hpack.HeaderField) { case "grpc-status": code, err := strconv.Atoi(f.Value) if err != nil { - d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-status: %v", err) + d.data.grpcStatus = status.Newf(codes.Internal, "transport: malformed grpc-status: %v", err) return } d.data.rawStatusCode = &code @@ -398,26 +402,26 @@ func (d *decodeState) processHeaderField(f hpack.HeaderField) { case "biz-status": code, err := strconv.Atoi(f.Value) if err != nil { - d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed biz-status: %v", err) + d.data.grpcStatus = status.Newf(codes.Internal, "transport: malformed biz-status: %v", err) return } d.data.bizStatusCode = &code case "biz-extra": extra, err := utils.JSONStr2Map(f.Value) if err != nil { - d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed biz-extra: %v", err) + d.data.grpcStatus = status.Newf(codes.Internal, "transport: malformed biz-extra: %v", err) return } d.data.bizStatusExtra = extra case "grpc-status-details-bin": v, err := decodeBinHeader(f.Value) if err != nil { - d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err) + d.data.grpcStatus = status.Newf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err) return } s := &spb.Status{} if err := proto.Unmarshal(v, s); err != nil { - d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err) + d.data.grpcStatus = status.Newf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err) return } d.data.statusGen = status.FromProto(s) @@ -425,21 +429,21 @@ func (d *decodeState) processHeaderField(f hpack.HeaderField) { d.data.timeoutSet = true var err error if d.data.timeout, err = decodeTimeout(f.Value); err != nil { - d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed time-out: %v", err) + d.data.grpcStatus = status.Newf(codes.Internal, "transport: malformed time-out: %v", err) } case ":path": d.data.method = f.Value case ":status": code, err := strconv.Atoi(f.Value) if err != nil { - d.data.httpErr = status.Errorf(codes.Internal, "transport: malformed http-status: %v", err) + d.data.httpErrStatus = status.Newf(codes.Internal, "transport: malformed http-status: %v", err) return } d.data.httpStatus = &code case "grpc-tags-bin": v, err := decodeBinHeader(f.Value) if err != nil { - d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-tags-bin: %v", err) + d.data.grpcStatus = status.Newf(codes.Internal, "transport: malformed grpc-tags-bin: %v", err) return } d.data.statsTags = v @@ -447,7 +451,7 @@ func (d *decodeState) processHeaderField(f hpack.HeaderField) { case "grpc-trace-bin": v, err := decodeBinHeader(f.Value) if err != nil { - d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-trace-bin: %v", err) + d.data.grpcStatus = status.Newf(codes.Internal, "transport: malformed grpc-trace-bin: %v", err) return } d.data.statsTrace = v @@ -642,3 +646,96 @@ func decodeGrpcMessageUnchecked(msg string) string { } return buf.String() } + +var ( + rstCode2MappingErrMap = map[http2.ErrCode]error{} + mappingErr2RstCodeMap = map[error]http2.ErrCode{} +) + +var customRstCodeEnabled = false + +// SetCustomRstCodeEnabled enables/disables the custom RstCode and err mapping. +// it is off by default. +func SetCustomRstCodeEnabled(flag bool) { + customRstCodeEnabled = flag +} + +// RegisterCustomRstCode registers a mapping between a custom RstCode and a mapping err. +// The mapping only works when SetCustomRstCodeEnabled(true) is invoked. +// e.g. RegisterCustomRstCode(1000, kerrors.ErrGracefulShutdown) +// When Kitex receives a RST_STREAM frame with error code 1000, it will inject kerrors.ErrGracefulShutdown into +// *status.Error that users will conceive +func RegisterCustomRstCode(rstCode uint32, mappingErr error) { + h2Code := http2.ErrCode(rstCode) + rstCode2MappingErrMap[h2Code] = mappingErr + mappingErr2RstCodeMap[mappingErr] = h2Code +} + +func isCustomRstCodeEnabled() bool { + return customRstCodeEnabled +} + +func cleanupCustomRstCodeMapping() { + rstCode2MappingErrMap = map[http2.ErrCode]error{} + mappingErr2RstCodeMap = map[error]http2.ErrCode{} +} + +func getMappingErrAndStatusCode(ctx context.Context, rstCode http2.ErrCode) (error, codes.Code) { + mappingErr := getMappingErr(rstCode) + stCode := getStatusCode(ctx, rstCode, mappingErr) + return mappingErr, stCode +} + +func getMappingErr(rstCode http2.ErrCode) error { + err, ok := rstCode2MappingErrMap[rstCode] + if !isCustomRstCodeEnabled() || !ok { + return gerrors.ErrRecvRstStream + } + return err +} + +func getStatusCode(ctx context.Context, errCode http2.ErrCode, mappingErr error) codes.Code { + stCode, ok := http2ErrConvTab[errCode] + if !ok { + if !isCustomRstCodeEnabled() { + klog.Warnf("transport: http2Client.handleRSTStream found no mapped gRPC status for the received nhttp2 error %v", errCode) + stCode = codes.Unknown + } else { + stCode = codes.Canceled + if errors.Is(mappingErr, kerrors.ErrGracefulShutdown) { + stCode = codes.Unavailable + } + if errors.Is(mappingErr, kerrors.ErrStreamTimeout) { + stCode = codes.DeadlineExceeded + } + } + } + if stCode == codes.Canceled { + if d, ok := ctx.Deadline(); ok && !d.After(time.Now()) { + // Our deadline was already exceeded, and that was likely the cause + // of this cancelation. Alter the status code accordingly. + stCode = codes.DeadlineExceeded + } + } + return stCode +} + +func getRstCode(err error) (rstCode http2.ErrCode) { + if err == nil || err == io.EOF { + return http2.ErrCodeNo + } + rstCode = http2.ErrCodeCancel + statusErr, ok := err.(*status.Error) + if !ok { + return + } + mappingErr := statusErr.GetMappingErr() + if mappingErr == nil { + return + } + code, ok := mappingErr2RstCodeMap[mappingErr] + if !ok { + return + } + return code +} diff --git a/pkg/remote/trans/nphttp2/grpc/http_util_test.go b/pkg/remote/trans/nphttp2/grpc/http_util_test.go index 9cd1592413..d89859b444 100644 --- a/pkg/remote/trans/nphttp2/grpc/http_util_test.go +++ b/pkg/remote/trans/nphttp2/grpc/http_util_test.go @@ -19,10 +19,17 @@ package grpc import ( "bytes" "context" + "io" "testing" "time" + "golang.org/x/net/http2" + "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/kerrors" + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/codes" + gerrors "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/errors" + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/status" ) func TestEncoding(t *testing.T) { @@ -212,3 +219,188 @@ func TestConnectionError(t *testing.T) { ori = connectionError.Origin() test.Assert(t, ori == context.Canceled) } + +func Test_getMappingErrAndStatusCode(t *testing.T) { + testcases := []struct { + desc string + input []http2.ErrCode + want []struct { + err error + stCode codes.Code + } + setup func(t *testing.T) + clean func(t *testing.T) + }{ + { + desc: "normal RstCode", + input: []http2.ErrCode{ + http2.ErrCodeNo, + http2.ErrCodeCancel, + }, + want: []struct { + err error + stCode codes.Code + }{ + { + err: gerrors.ErrRecvRstStream, + stCode: codes.Internal, + }, + { + err: gerrors.ErrRecvRstStream, + stCode: codes.Canceled, + }, + }, + }, + { + desc: "custom RstCode", + setup: customSetup, + input: []http2.ErrCode{ + http2.ErrCode(1000), + http2.ErrCode(1001), + http2.ErrCode(1002), + http2.ErrCode(1003), + http2.ErrCode(1004), + http2.ErrCode(1005), + http2.ErrCode(1006), + http2.ErrCode(1007), + http2.ErrCode(1008), + http2.ErrCode(1009), + }, + want: []struct { + err error + stCode codes.Code + }{ + { + err: kerrors.ErrGracefulShutdown, + stCode: codes.Unavailable, + }, + { + err: kerrors.ErrBizCanceled, + stCode: codes.Canceled, + }, + { + err: kerrors.ErrStreamingCanceled, + stCode: codes.Canceled, + }, + { + err: kerrors.ErrStreamTimeout, + stCode: codes.DeadlineExceeded, + }, + { + err: gerrors.ErrMiddleHeader, + stCode: codes.Canceled, + }, + { + err: gerrors.ErrDecodeHeader, + stCode: codes.Canceled, + }, + { + err: gerrors.ErrHTTP2Stream, + stCode: codes.Canceled, + }, + { + err: gerrors.ErrClosedWithoutTrailer, + stCode: codes.Canceled, + }, + { + err: gerrors.ErrStreamFlowControl, + stCode: codes.Canceled, + }, + { + err: kerrors.ErrMetaSizeExceeded, + stCode: codes.Canceled, + }, + }, + clean: customClean, + }, + } + + for _, tc := range testcases { + t.Run(tc.desc, func(t *testing.T) { + if tc.setup != nil { + tc.setup(t) + } + if tc.clean != nil { + defer tc.clean(t) + } + for i, rstCode := range tc.input { + err, stCode := getMappingErrAndStatusCode(context.Background(), rstCode) + test.Assert(t, err == tc.want[i].err, err) + res := stCode == tc.want[i].stCode + if !res { + t.Logf("stCode: %d, err: %v", stCode, err) + } + test.Assert(t, stCode == tc.want[i].stCode, stCode) + } + }) + } +} + +func Test_getRstCode(t *testing.T) { + testcases := []struct { + desc string + input []error + want []http2.ErrCode + setup func(*testing.T) + clean func(*testing.T) + }{ + { + desc: "normal RstCode", + input: []error{ + nil, io.EOF, + status.New(codes.Internal, "test").WithMappingErr(kerrors.ErrGracefulShutdown).Err(), + }, + want: []http2.ErrCode{ + http2.ErrCodeNo, http2.ErrCodeNo, + http2.ErrCodeCancel, + }, + }, + { + desc: "custom RstCode", + input: []error{ + nil, io.EOF, + status.New(codes.Internal, "test").WithMappingErr(kerrors.ErrGracefulShutdown).Err(), + }, + want: []http2.ErrCode{ + http2.ErrCodeNo, http2.ErrCodeNo, + http2.ErrCode(1000), + }, + setup: customSetup, + clean: customClean, + }, + } + + for _, tc := range testcases { + t.Run(tc.desc, func(t *testing.T) { + if tc.setup != nil { + tc.setup(t) + } + if tc.clean != nil { + defer tc.clean(t) + } + for i, err := range tc.input { + rstCode := getRstCode(err) + test.Assert(t, rstCode == tc.want[i], rstCode) + } + }) + } +} + +func customSetup(t *testing.T) { + SetCustomRstCodeEnabled(true) + RegisterCustomRstCode(1000, kerrors.ErrGracefulShutdown) + RegisterCustomRstCode(1001, kerrors.ErrBizCanceled) + RegisterCustomRstCode(1002, kerrors.ErrStreamingCanceled) + RegisterCustomRstCode(1003, kerrors.ErrStreamTimeout) + RegisterCustomRstCode(1004, gerrors.ErrMiddleHeader) + RegisterCustomRstCode(1005, gerrors.ErrDecodeHeader) + RegisterCustomRstCode(1006, gerrors.ErrHTTP2Stream) + RegisterCustomRstCode(1007, gerrors.ErrClosedWithoutTrailer) + RegisterCustomRstCode(1008, gerrors.ErrStreamFlowControl) + RegisterCustomRstCode(1009, kerrors.ErrMetaSizeExceeded) +} + +func customClean(t *testing.T) { + SetCustomRstCodeEnabled(false) + cleanupCustomRstCodeMapping() +} diff --git a/pkg/remote/trans/nphttp2/grpc/stream_test.go b/pkg/remote/trans/nphttp2/grpc/stream_test.go new file mode 100644 index 0000000000..92dd9ab9c0 --- /dev/null +++ b/pkg/remote/trans/nphttp2/grpc/stream_test.go @@ -0,0 +1,36 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package grpc + +import ( + "testing" + + "github.com/cloudwego/kitex/internal/test" +) + +func TestStream_SetServiceMeta(t *testing.T) { + s := new(Stream) + test.Assert(t, s.getSourceService() == "remote service") + + testSvc := "test service" + err := s.SetServiceMeta(SourceServiceMetaKey, testSvc) + test.Assert(t, err == nil, err) + test.Assert(t, s.getSourceService() == "test service") + + err = s.SetServiceMeta("invalid", testSvc) + test.Assert(t, err != nil, err) +} diff --git a/pkg/remote/trans/nphttp2/grpc/transport.go b/pkg/remote/trans/nphttp2/grpc/transport.go index e2b5f0ef1f..81099624bf 100644 --- a/pkg/remote/trans/nphttp2/grpc/transport.go +++ b/pkg/remote/trans/nphttp2/grpc/transport.go @@ -35,7 +35,9 @@ import ( "sync/atomic" "github.com/cloudwego/kitex/pkg/kerrors" + "github.com/cloudwego/kitex/pkg/klog" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/codes" + gerrors "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/errors" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/metadata" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/status" ) @@ -198,7 +200,7 @@ func (r *recvBufferReader) readClient(p []byte) (n int, err error) { // TODO: delaying ctx error seems like a unnecessary side effect. What // we really want is to mark the stream as done, and return ctx error // faster. - r.closeStream(ContextErr(r.ctx.Err())) + r.closeStream(clientContextErr(r.ctx.Err())) m := <-r.recv.get() return r.readAdditional(m, p) case m := <-r.recv.get(): @@ -236,7 +238,7 @@ type Stream struct { st ServerTransport // nil for client side Stream ct *http2Client // nil for server side Stream ctx context.Context // the associated context of the stream - cancel cancelWithReason // always nil for client side Stream + cancelFunc cancelWithReason // always nil for client side Stream done chan struct{} // closed at the end of stream to unblock writers. On the client side. ctxDone <-chan struct{} // same as done chan but for server side. Cache of ctx.Done() (for performance) method string // the associated RPC method of the stream @@ -285,6 +287,11 @@ type Stream struct { // contentSubtype is the content-subtype for requests. // this must be lowercase or the behavior is undefined. contentSubtype string + + // closeStreamErr is used to store the error when stream is closed + closeStreamErr atomic.Value + // sourceService is the source service name of this stream + sourceService atomic.Value } // isHeaderSent is only valid on the server-side. @@ -479,6 +486,60 @@ func (s *Stream) Read(p []byte) (n int, err error) { return io.ReadFull(s.trReader, p) } +func (s *Stream) getCloseStreamErr() error { + rawErr := s.closeStreamErr.Load() + if rawErr != nil { + return rawErr.(error) + } + return errStatusStreamDone +} + +type svcMetaKey string + +const ( + SourceServiceMetaKey svcMetaKey = "source_service" +) + +// SetServiceMeta is used to inject service-related metadata +func (s *Stream) SetServiceMeta(key svcMetaKey, val interface{}) error { + switch key { + case SourceServiceMetaKey: + svc, ok := val.(string) + if !ok { + return fmt.Errorf("%s expect string val, but got %v", SourceServiceMetaKey, val) + } + s.setSourceService(svc) + default: + return fmt.Errorf("unknown svcMetaKey: %s", key) + } + return nil +} + +func (s *Stream) setSourceService(svc string) { + s.sourceService.Store(svc) +} + +func (s *Stream) getSourceService() string { + rawSvc := s.sourceService.Load() + if rawSvc != nil { + return rawSvc.(string) + } + return "remote service" +} + +// err should be of the same type +func (s *Stream) cancel(err error) { + if err == nil { + return + } + if _, ok := err.(*status.Error); !ok { + klog.CtxWarnf(s.ctx, "stream canceled with non status.Error: %v", err) + err = status.New(codes.Canceled, err.Error()).WithMappingErr(kerrors.ErrStreamingCanceled).Err() + } + // all errors propagated by cancelFunc must be of type *status.Error or nil + s.cancelFunc(err) +} + // StreamWrite only used for unit test func StreamWrite(s *Stream, buffer *bytes.Buffer) { s.write(recvMsg{buffer: buffer}) @@ -496,10 +557,8 @@ func CreateStream(ctx context.Context, id uint32, requestRead func(i int), metho }, windowHandler: func(i int) {}, } - stream := &Stream{ id: id, - ctx: ctx, method: method, buf: recvBuffer, trReader: trReader, @@ -508,6 +567,9 @@ func CreateStream(ctx context.Context, id uint32, requestRead func(i int), metho hdrMu: sync.Mutex{}, } + ctx, cancel := context.WithCancel(ctx) + stream.ctx, stream.cancelFunc = newContextWithCancelReason(ctx, cancel) + return stream } @@ -762,20 +824,23 @@ func (e ConnectionError) Origin() error { var ( // ErrConnClosing indicates that the transport is closing. - ErrConnClosing = connectionErrorf(true, nil, "transport is closing") + ErrConnClosing = connectionErrorf(true, nil, "transport is closing") + errStatusConnClosing = status.New(codes.Unavailable, "transport is closing"). + WithMappingErr(gerrors.ErrConnectionIsClosing).Err() + errStatusControlBufFinished = status.New(codes.Unavailable, "controlbuf finished"). + WithMappingErr(gerrors.ErrControlBufFinished).Err() // errStreamDone is returned from write at the client side to indicate application // layer of an error. - errStreamDone = errors.New("the stream is done") + errStreamDone = errors.New("the stream is done") + errStatusStreamDone = status.New(codes.Internal, errStreamDone.Error()). + WithMappingErr(gerrors.ErrStreamIsDone).Err() - // errStreamDrain indicates that the stream is rejected because the + // errStatusStreamDrain indicates that the stream is rejected because the // connection is draining. This could be caused by goaway or balancer // removing the address. - errStreamDrain = status.New(codes.Unavailable, "the connection is draining").Err() - - // StatusGoAway indicates that the server sent a GOAWAY that included this - // stream's ID in unprocessed RPCs. - statusGoAway = status.New(codes.Unavailable, "the stream is rejected because server is draining the connection") + errStatusStreamDrain = status.New(codes.Unavailable, "the connection is draining"). + WithMappingErr(gerrors.ErrStreamDrain).Err() ) // GoAwayReason contains the reason for the GoAway frame received. @@ -796,15 +861,36 @@ const ( func ContextErr(err error) error { switch err { case context.DeadlineExceeded: - return status.New(codes.DeadlineExceeded, err.Error()).Err() + return status.New(codes.DeadlineExceeded, err.Error()).WithMappingErr(kerrors.ErrStreamTimeout).Err() case context.Canceled: - return status.New(codes.Canceled, err.Error()).Err() + return status.New(codes.Canceled, err.Error()).WithMappingErr(kerrors.ErrBizCanceled).Err() } statusErr, ok := err.(*status.Error) if ok { // only returned by contextWithCancelReason return statusErr } - return status.Errorf(codes.Internal, "Unexpected error from context packet: %v", err) + return status.Newf(codes.Internal, "Unexpected error from context packet: %v", err).WithMappingErr(kerrors.ErrStreamingCanceled).Err() +} + +// clientContextErr converts the error from context package into a status error +// when the error is passed through streams by cancel. +func clientContextErr(err error) error { + stErr := ContextErr(err).(*status.Error) + switch { + // errors defined here could pass through streams by cancel + // e.g. A -> B -> C + // stream between A and B closed by Graceful Shutdown, BC.Recv() could get kerrors.ErrGracefulShutdown + case errors.Is(stErr, kerrors.ErrStreamTimeout): + case errors.Is(stErr, kerrors.ErrBizCanceled): + case errors.Is(stErr, kerrors.ErrGracefulShutdown): + default: + // Other errs are treated as kerrors.ErrStreamingCanceled + // when passed through streams by cancel. + // Then users could use errors.Is(err, kerrors.ErrStreamingCanceled) + // to check if an exception in the upstream stream caused cancel to be delivered + return stErr.GRPCStatus().WithMappingErr(kerrors.ErrStreamingCanceled).Err() + } + return stErr } // IsStreamDoneErr returns true if the error indicates that the stream is done. @@ -840,3 +926,9 @@ func tlsAppendH2ToALPNProtocols(ps []string) []string { ret = append(ret, ps...) return append(ret, alpnProtoStrH2) } + +var ( + sendRSTStreamFrameSuffix = " [send RSTStream Frame]" + triggeredByRemoteServiceSuffix = " [triggered by remote service]" + triggeredByHandlerSideSuffix = " [triggered by handler side]" +) diff --git a/pkg/remote/trans/nphttp2/grpc/transport_test.go b/pkg/remote/trans/nphttp2/grpc/transport_test.go index ec98259a52..f5a3a5f271 100644 --- a/pkg/remote/trans/nphttp2/grpc/transport_test.go +++ b/pkg/remote/trans/nphttp2/grpc/transport_test.go @@ -30,6 +30,7 @@ import ( "io" "math" "net" + "reflect" "runtime" "strconv" "strings" @@ -42,7 +43,9 @@ import ( "golang.org/x/net/http2/hpack" "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/kerrors" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/codes" + gerrors "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/errors" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/grpc/grpcframe" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/grpc/testutils" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/status" @@ -822,8 +825,8 @@ func TestLargeMessageWithDelayRead(t *testing.T) { // return // } // ct.write(str, nil, nil, &Options{Last: true}) -// if _, err := str.Read(make([]byte, 8)); err != errStreamDrain && err != ErrConnClosing { -// t.Errorf("_.Read(_) = _, %v, want _, %v or %v", err, errStreamDrain, ErrConnClosing) +// if _, err := str.Read(make([]byte, 8)); err != errStatusStreamDrain && err != ErrConnClosing { +// t.Errorf("_.Read(_) = _, %v, want _, %v or %v", err, errStatusStreamDrain, ErrConnClosing) // } // }() // } @@ -858,11 +861,9 @@ func TestLargeMessageSuspension(t *testing.T) { msg := make([]byte, initialWindowSize*8) ct.Write(s, nil, msg, &Options{}) err = ct.Write(s, nil, msg, &Options{Last: true}) - if err != errStreamDone { - t.Fatalf("write got %v, want io.EOF", err) - } - expectedErr := status.Err(codes.DeadlineExceeded, context.DeadlineExceeded.Error()) - if _, err := s.Read(make([]byte, 8)); err.Error() != expectedErr.Error() { + test.Assert(t, errors.Is(err, ContextErr(ctx.Err())), err) + expectedErr := status.New(codes.DeadlineExceeded, context.DeadlineExceeded.Error()).WithMappingErr(kerrors.ErrStreamTimeout) + if _, err := s.Read(make([]byte, 8)); reflect.DeepEqual(err, expectedErr) { t.Fatalf("Read got %v of type %T, want %v", err, err, expectedErr) } ct.Close(errSelfCloseForTest) @@ -892,7 +893,7 @@ func TestMaxStreams(t *testing.T) { pctx, cancel := context.WithCancel(context.Background()) defer cancel() timer := time.NewTimer(time.Second * 10) - expectedErr := status.Err(codes.DeadlineExceeded, context.DeadlineExceeded.Error()) + expectedErr := status.New(codes.DeadlineExceeded, context.DeadlineExceeded.Error()).WithMappingErr(kerrors.ErrStreamTimeout) for { select { case <-timer.C: @@ -906,7 +907,7 @@ func TestMaxStreams(t *testing.T) { if str, err := ct.NewStream(ctx, callHdr); err == nil { slist = append(slist, str) continue - } else if err.Error() != expectedErr.Error() { + } else if reflect.DeepEqual(err, expectedErr) { t.Fatalf("ct.NewStream(_,_) = _, %v, want _, %v", err, expectedErr) } timer.Stop() @@ -994,8 +995,9 @@ func TestServerContextCanceledOnClosedConnection(t *testing.T) { ct.Close(errSelfCloseForTest) select { case <-ss.Context().Done(): - if ss.Context().Err() != errConnectionEOF { - t.Fatalf("ss.Context().Err() got %v, want %v", ss.Context().Err(), errConnectionEOF) + cErr := ss.Context().Err() + if !errors.Is(cErr, gerrors.ErrHTTP2Connection) { + t.Fatalf("ss.Context().Err() got %v, want %v", ss.Context().Err(), errStatusConnectionEOF) } case <-time.After(3 * time.Second): t.Fatalf("%s", "Failed to cancel the context of the sever side stream.") @@ -1531,8 +1533,8 @@ func TestContextErr(t *testing.T) { // outputs errOut error }{ - {context.DeadlineExceeded, status.Err(codes.DeadlineExceeded, context.DeadlineExceeded.Error())}, - {context.Canceled, status.Err(codes.Canceled, context.Canceled.Error())}, + {context.DeadlineExceeded, status.New(codes.DeadlineExceeded, context.DeadlineExceeded.Error()).WithMappingErr(kerrors.ErrStreamTimeout).Err()}, + {context.Canceled, status.New(codes.Canceled, context.Canceled.Error()).WithMappingErr(kerrors.ErrBizCanceled).Err()}, } { err := ContextErr(test.errIn) if err.Error() != test.errOut.Error() { @@ -2050,3 +2052,44 @@ func TestTlsAppendH2ToALPNProtocols(t *testing.T) { appended = tlsAppendH2ToALPNProtocols(appended) test.Assert(t, len(appended) == 1) } + +func Test_clientContextErr(t *testing.T) { + testcases := []struct { + desc string + input error + target error + }{ + { + desc: "user invokes cancel()", + input: context.Canceled, + target: kerrors.ErrBizCanceled, + }, + { + desc: "ctx timeout", + input: context.DeadlineExceeded, + target: kerrors.ErrStreamTimeout, + }, + { + desc: "kerrors.ErrGracefulShutdown pass through", + input: status.New(codes.Internal, "pass through").WithMappingErr(kerrors.ErrGracefulShutdown).Err(), + target: kerrors.ErrGracefulShutdown, + }, + { + desc: "kerrors.ErrBizCanceled pass through", + input: status.New(codes.Internal, "pass through").WithMappingErr(kerrors.ErrBizCanceled).Err(), + target: kerrors.ErrBizCanceled, + }, + { + desc: "non-pass through", + input: status.New(codes.Internal, "non-pass through").WithMappingErr(gerrors.ErrStreamFlowControl).Err(), + target: kerrors.ErrStreamingCanceled, + }, + } + + for _, tc := range testcases { + t.Run(tc.desc, func(t *testing.T) { + err := clientContextErr(tc.input) + test.Assert(t, errors.Is(err, tc.target), err) + }) + } +} diff --git a/pkg/remote/trans/nphttp2/server_handler.go b/pkg/remote/trans/nphttp2/server_handler.go index f249f84242..f60a288f31 100644 --- a/pkg/remote/trans/nphttp2/server_handler.go +++ b/pkg/remote/trans/nphttp2/server_handler.go @@ -185,6 +185,7 @@ func (t *svrTransHandler) handleFunc(s *grpcTransport.Stream, svrTrans *SvrTrans return } } + s.SetServiceMeta(grpcTransport.SourceServiceMetaKey, ri.From().ServiceName()) rCtx = t.startTracer(rCtx, ri) defer func() { panicErr := recover() diff --git a/pkg/remote/trans/nphttp2/status/status.go b/pkg/remote/trans/nphttp2/status/status.go index 130d4425cd..09afb831be 100644 --- a/pkg/remote/trans/nphttp2/status/status.go +++ b/pkg/remote/trans/nphttp2/status/status.go @@ -49,6 +49,8 @@ type Iface interface { // and should be created with New, Newf, or FromProto. type Status struct { s *spb.Status + // kerr is the Kitex custom error that status maps to + kerr error } // New returns a Status representing c and msg. @@ -119,7 +121,7 @@ func (s *Status) Err() error { if s.Code() == codes.OK { return nil } - return &Error{e: s.Proto()} + return &Error{e: s.Proto(), kerr: s.kerr} } // WithDetails returns a new status with the provided details messages appended to the status. @@ -158,29 +160,54 @@ func (s *Status) Details() []interface{} { return details } +// WithMappingErr creates a new Status and injects Kitex mapping err +func (s *Status) WithMappingErr(kerr error) *Status { + return &Status{ + s: s.s, + kerr: kerr, + } +} + // Error wraps a pointer of a status proto. It implements error and Status, // and a nil *Error should never be returned by this package. type Error struct { e *spb.Status + // kerr is the Kitex custom error that status maps to + kerr error +} + +// GetMappingErr returns the Kitex custom error that status Error maps to +func (e *Error) GetMappingErr() error { + return e.kerr } func (e *Error) Error() string { - return fmt.Sprintf("rpc error: code = %d desc = %s", codes.Code(e.e.GetCode()), e.e.GetMessage()) + str := fmt.Sprintf("rpc error: code = %d desc = %s", codes.Code(e.e.GetCode()), e.e.GetMessage()) + if e.kerr == nil { + return str + } + return fmt.Sprintf("[%s] %s", e.kerr.Error(), str) } // GRPCStatus returns the Status represented by se. func (e *Error) GRPCStatus() *Status { - return FromProto(e.e) + st := FromProto(e.e) + st.kerr = e.kerr + return st } // Is implements future error.Is functionality. -// A Error is equivalent if the code and message are identical. +// A Error is equivalent if the code and message are identical +// or if the underlying mapped kitex error conforms to errors.Is. func (e *Error) Is(target error) bool { tse, ok := target.(*Error) - if !ok { - return false + if ok { + return proto.Equal(e.e, tse.e) + } + if e.kerr != nil { + return errors.Is(e.kerr, target) } - return proto.Equal(e.e, tse.e) + return false } // FromError returns a Status representing err if it was produced from this diff --git a/pkg/remote/trans/nphttp2/status/status_test.go b/pkg/remote/trans/nphttp2/status/status_test.go index 08cd55d82f..98b9c1d393 100644 --- a/pkg/remote/trans/nphttp2/status/status_test.go +++ b/pkg/remote/trans/nphttp2/status/status_test.go @@ -18,7 +18,10 @@ package status import ( "context" + "errors" "fmt" + "reflect" + "strings" "testing" spb "google.golang.org/genproto/googleapis/rpc/status" @@ -63,6 +66,22 @@ func TestStatus(t *testing.T) { statusNilErr, ok := FromError(nil) test.Assert(t, ok) test.Assert(t, statusNilErr == nil) + + mappingErr := errors.New("mappingErr") + oriSt := New(codes.Internal, "withMappingErr test").WithMappingErr(mappingErr) + rawStErr := oriSt.Err() + test.Assert(t, strings.Contains(rawStErr.Error(), mappingErr.Error()), rawStErr) + test.Assert(t, errors.Is(rawStErr, mappingErr), rawStErr) + stErr, ok := rawStErr.(*Error) + test.Assert(t, ok) + test.Assert(t, stErr.GetMappingErr() == mappingErr, stErr.GetMappingErr()) + st0 := stErr.GRPCStatus() + test.Assert(t, reflect.DeepEqual(st0, oriSt), st0) + st1, ok := FromError(rawStErr) + test.Assert(t, ok) + test.Assert(t, reflect.DeepEqual(st1, oriSt), st1) + st2 := Convert(rawStErr) + test.Assert(t, reflect.DeepEqual(st2, oriSt), st1) } func TestError(t *testing.T) { @@ -70,17 +89,19 @@ func TestError(t *testing.T) { s.Code = 1 s.Message = "test err" - er := &Error{s} + kerr := errors.New("kerr") + er := &Error{e: s, kerr: kerr} test.Assert(t, len(er.Error()) > 0) + test.Assert(t, strings.Contains(er.Error(), s.Message), er.Error()) + test.Assert(t, strings.Contains(er.Error(), kerr.Error()), er.Error()) status := er.GRPCStatus() test.Assert(t, status.Message() == s.Message) - is := er.Is(context.Canceled) - test.Assert(t, !is) + test.Assert(t, !er.Is(context.Canceled)) - is = er.Is(er) - test.Assert(t, is) + test.Assert(t, er.Is(er)) + test.Assert(t, er.Is(kerr)) } func TestFromContextError(t *testing.T) { @@ -101,7 +122,7 @@ func TestFromContextError(t *testing.T) { s := new(spb.Status) s.Code = 1 s.Message = "test err" - grpcErr := &Error{s} + grpcErr := &Error{e: s} // grpc err codeGrpcErr := Code(grpcErr) test.Assert(t, codeGrpcErr == codes.Canceled) diff --git a/pkg/streamx/provider/grpc/gerrors/gerrors.go b/pkg/streamx/provider/grpc/gerrors/gerrors.go new file mode 100644 index 0000000000..49015703bb --- /dev/null +++ b/pkg/streamx/provider/grpc/gerrors/gerrors.go @@ -0,0 +1,62 @@ +/* + * + * Copyright 2024 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2021 CloudWeGo Authors. + */ + +package gerrors + +import ( + grpc_errors "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/errors" + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/grpc" +) + +var ( + // grpc stream error + ErrHTTP2Stream = grpc_errors.ErrHTTP2Stream + ErrClosedWithoutTrailer = grpc_errors.ErrClosedWithoutTrailer + ErrMiddleHeader = grpc_errors.ErrMiddleHeader + ErrDecodeHeader = grpc_errors.ErrDecodeHeader + ErrRecvRstStream = grpc_errors.ErrRecvRstStream + ErrStreamDrain = grpc_errors.ErrStreamDrain + ErrStreamFlowControl = grpc_errors.ErrStreamFlowControl + ErrIllegalHeaderWrite = grpc_errors.ErrIllegalHeaderWrite + ErrStreamIsDone = grpc_errors.ErrStreamIsDone + ErrMaxStreamExceeded = grpc_errors.ErrMaxStreamExceeded + + // grpc connection error + ErrHTTP2Connection = grpc_errors.ErrHTTP2Connection + ErrEstablishConnection = grpc_errors.ErrEstablishConnection + ErrHandleGoAway = grpc_errors.ErrHandleGoAway + ErrKeepAlive = grpc_errors.ErrKeepAlive + ErrOperateHeaders = grpc_errors.ErrOperateHeaders + ErrNoActiveStream = grpc_errors.ErrNoActiveStream + ErrControlBufFinished = grpc_errors.ErrControlBufFinished + ErrNotReachable = grpc_errors.ErrNotReachable + ErrConnectionIsClosing = grpc_errors.ErrConnectionIsClosing +) + +// SetCustomRstCodeEnabled enables/disables the custom RstCode and err mapping. +// it is off by default. +var SetCustomRstCodeEnabled = grpc.SetCustomRstCodeEnabled + +// RegisterCustomRstCode registers a mapping between a custom RstCode and a mapping err. +// The mapping only works when SetCustomRstCodeEnabled(true) is invoked. +// e.g. RegisterCustomRstCode(1000, kerrors.ErrGracefulShutdown) +// When Kitex receives a RST_STREAM frame with error code 1000, it will inject kerrors.ErrGracefulShutdown into +// *status.Error that users will conceive +var RegisterCustomRstCode = grpc.RegisterCustomRstCode