From 868302107425e8962e0cf19464ad7563486c886d Mon Sep 17 00:00:00 2001 From: "yuxuan.wang1" Date: Mon, 16 Sep 2024 15:37:20 +0800 Subject: [PATCH] return real error for client-side streaming.Send --- pkg/remote/trans/nphttp2/grpc/http2_client.go | 16 ++++++++++++---- pkg/remote/trans/nphttp2/grpc/transport.go | 14 +++++++++++++- pkg/remote/trans/nphttp2/grpc/transport_test.go | 4 +--- 3 files changed, 26 insertions(+), 8 deletions(-) diff --git a/pkg/remote/trans/nphttp2/grpc/http2_client.go b/pkg/remote/trans/nphttp2/grpc/http2_client.go index 7206ddfa0c..d29f3eda2d 100644 --- a/pkg/remote/trans/nphttp2/grpc/http2_client.go +++ b/pkg/remote/trans/nphttp2/grpc/http2_client.go @@ -557,6 +557,15 @@ func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2. // This will unblock reads eventually. s.write(recvMsg{err: err}) } + + // store closeStreamErr + if err == io.EOF { + err = st.Err() + } + if err != nil { + s.closeStreamErr.Store(err) + } + // If headerChan isn't closed, then close it. if atomic.CompareAndSwapUint32(&s.headerChanClosed, 0, 1) { s.noHeaders = true @@ -598,7 +607,6 @@ func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2. // addrConn level and blocks until the addrConn is successfully connected. func (t *http2Client) Close(err error) error { if rawErr, ok := err.(ConnectionError); ok { - err = rawErr err = status.Err(codes.Unavailable, rawErr.Desc) } t.mu.Lock() @@ -660,10 +668,10 @@ func (t *http2Client) Write(s *Stream, hdr, data []byte, opts *Options) error { if opts.Last { // If it's the last message, update stream state. if !s.compareAndSwapState(streamActive, streamWriteDone) { - return errStreamDone + return s.getCloseStreamErr() } } else if s.getState() != streamActive { - return errStreamDone + return s.getCloseStreamErr() } df := newDataFrame() df.streamID = s.id @@ -674,7 +682,7 @@ func (t *http2Client) Write(s *Stream, hdr, data []byte, opts *Options) error { df.originD = df.d if hdr != nil || data != nil { // If it's not an empty data frame, check quota. if err := s.wq.get(int32(len(hdr) + len(data))); err != nil { - return err + return s.getCloseStreamErr() } } return t.controlBuf.put(df) diff --git a/pkg/remote/trans/nphttp2/grpc/transport.go b/pkg/remote/trans/nphttp2/grpc/transport.go index 46e8272255..b17b0635fa 100644 --- a/pkg/remote/trans/nphttp2/grpc/transport.go +++ b/pkg/remote/trans/nphttp2/grpc/transport.go @@ -285,6 +285,9 @@ type Stream struct { // contentSubtype is the content-subtype for requests. // this must be lowercase or the behavior is undefined. contentSubtype string + + // closeStreamErr is used to store the error when stream is closed + closeStreamErr atomic.Value } // isHeaderSent is only valid on the server-side. @@ -479,6 +482,14 @@ func (s *Stream) Read(p []byte) (n int, err error) { return io.ReadFull(s.trReader, p) } +func (s *Stream) getCloseStreamErr() error { + rawErr := s.closeStreamErr.Load() + if rawErr != nil { + return rawErr.(error) + } + return errStatusStreamDone +} + // StreamWrite only used for unit test func StreamWrite(s *Stream, buffer *bytes.Buffer) { s.write(recvMsg{buffer: buffer}) @@ -768,7 +779,8 @@ var ( // errStreamDone is returned from write at the client side to indicate application // layer of an error. - errStreamDone = errors.New("the stream is done") + errStreamDone = errors.New("the stream is done") + errStatusStreamDone = status.Err(codes.Internal, errStreamDone.Error()) // errStreamDrain indicates that the stream is rejected because the // connection is draining. This could be caused by goaway or balancer diff --git a/pkg/remote/trans/nphttp2/grpc/transport_test.go b/pkg/remote/trans/nphttp2/grpc/transport_test.go index 9d2b2e698d..895a22ed80 100644 --- a/pkg/remote/trans/nphttp2/grpc/transport_test.go +++ b/pkg/remote/trans/nphttp2/grpc/transport_test.go @@ -947,9 +947,7 @@ func TestLargeMessageSuspension(t *testing.T) { msg := make([]byte, initialWindowSize*8) ct.Write(s, nil, msg, &Options{}) err = ct.Write(s, nil, msg, &Options{Last: true}) - if err != errStreamDone { - t.Fatalf("write got %v, want io.EOF", err) - } + test.Assert(t, errors.Is(err, ContextErr(ctx.Err())), err) expectedErr := status.Err(codes.DeadlineExceeded, context.DeadlineExceeded.Error()) if _, err := s.Read(make([]byte, 8)); err.Error() != expectedErr.Error() { t.Fatalf("Read got %v of type %T, want %v", err, err, expectedErr)