Skip to content

Commit

Permalink
feat: Supports callbacks when reading a message fails
Browse files Browse the repository at this point in the history
  • Loading branch information
tttoad committed Dec 27, 2024
1 parent b7beae5 commit 2ed37ef
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 8 deletions.
14 changes: 11 additions & 3 deletions server/callbacks.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ func (c CallbacksStruct) OnConnecting(request *http.Request) types.ConnectionRes
// ConnectionCallbacksStruct is a struct that implements ConnectionCallbacks interface and allows
// to override only the methods that are needed.
type ConnectionCallbacksStruct struct {
OnConnectedFunc func(ctx context.Context, conn types.Connection)
OnMessageFunc func(ctx context.Context, conn types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent
OnConnectionCloseFunc func(conn types.Connection)
OnConnectedFunc func(ctx context.Context, conn types.Connection)
OnMessageFunc func(ctx context.Context, conn types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent
OnConnectionCloseFunc func(conn types.Connection)
OnReadMessageErrorFunc func(conn types.Connection, mt int, msgByte []byte, err error)
}

var _ types.ConnectionCallbacks = (*ConnectionCallbacksStruct)(nil)
Expand Down Expand Up @@ -61,3 +62,10 @@ func (c ConnectionCallbacksStruct) OnConnectionClose(conn types.Connection) {
c.OnConnectionCloseFunc(conn)
}
}

// OnReadMessageError implements types.ConnectionCallbacks.
func (c ConnectionCallbacksStruct) OnReadMessageError(conn types.Connection, mt int, msgByte []byte, err error) {
if c.OnReadMessageErrorFunc != nil {
c.OnReadMessageErrorFunc(conn, mt, msgByte, err)
}
}
51 changes: 46 additions & 5 deletions server/serverimpl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,52 @@ func TestServerReceiveSendMessage(t *testing.T) {
assert.EqualValues(t, settings.CustomCapabilities, response.CustomCapabilities.Capabilities)
}

func TestServerReceiveSendErrorMessage(t *testing.T) {
var rcvMsg atomic.Value
type ErrorInfo struct {
mt int
msgByte []byte
err error
}
callbacks := CallbacksStruct{
OnConnectingFunc: func(request *http.Request) types.ConnectionResponse {
return types.ConnectionResponse{Accept: true, ConnectionCallbacks: ConnectionCallbacksStruct{
OnReadMessageErrorFunc: func(conn types.Connection, mt int, msgByte []byte, err error) {
rcvMsg.Store(ErrorInfo{
mt: mt,
msgByte: msgByte,
err: err,
})
},
}}
},
}

// Start a Server.
settings := &StartSettings{Settings: Settings{
Callbacks: callbacks,
CustomCapabilities: []string{"local.test.capability"},
}}
srv := startServer(t, settings)
defer srv.Stop(context.Background())

// Connect using a WebSocket client.
conn, _, _ := dialClient(settings)
require.NotNil(t, conn)
defer conn.Close()

// Send a message to the Server.
err := conn.WriteMessage(websocket.TextMessage, []byte(""))
require.NoError(t, err)

// Wait until Server receives the message.
eventually(t, func() bool { return rcvMsg.Load() != nil })
errInfo := rcvMsg.Load().(ErrorInfo)
assert.EqualValues(t, websocket.TextMessage, errInfo.mt)
assert.EqualValues(t, []byte(""), errInfo.msgByte)
assert.NotNil(t, errInfo.err)
}

func TestServerReceiveSendMessageWithCompression(t *testing.T) {
// Use highly compressible config body.
uncompressedCfg := []byte(strings.Repeat("test", 10000))
Expand Down Expand Up @@ -620,7 +666,6 @@ func TestServerAttachSendMessagePlainHTTP(t *testing.T) {
}

func TestServerHonoursClientRequestContentEncoding(t *testing.T) {

hc := http.Client{}
var rcvMsg atomic.Value
var onConnectedCalled, onCloseCalled int32
Expand Down Expand Up @@ -698,7 +743,6 @@ func TestServerHonoursClientRequestContentEncoding(t *testing.T) {
}

func TestServerHonoursAcceptEncoding(t *testing.T) {

hc := http.Client{}
var rcvMsg atomic.Value
var onConnectedCalled, onCloseCalled int32
Expand Down Expand Up @@ -985,7 +1029,6 @@ func BenchmarkSendToClient(b *testing.B) {
}
srv := New(&sharedinternal.NopLogger{})
err := srv.Start(*settings)

if err != nil {
b.Error(err)
}
Expand Down Expand Up @@ -1017,7 +1060,6 @@ func BenchmarkSendToClient(b *testing.B) {

for _, conn := range serverConnections {
err := conn.Send(context.Background(), &protobufs.ServerToAgent{})

if err != nil {
b.Error(err)
}
Expand All @@ -1026,5 +1068,4 @@ func BenchmarkSendToClient(b *testing.B) {
for _, conn := range clientConnections {
conn.Close()
}

}
3 changes: 3 additions & 0 deletions server/types/callbacks.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,7 @@ type ConnectionCallbacks interface {

// OnConnectionClose is called when the OpAMP connection is closed.
OnConnectionClose(conn Connection)

// OnConnectionError is called when an error occurs while reading or serializing a message.
OnReadMessageError(conn Connection, mt int, msgByte []byte, err error)
}

0 comments on commit 2ed37ef

Please sign in to comment.