diff --git a/internal/examples/server/opampsrv/opampsrv.go b/internal/examples/server/opampsrv/opampsrv.go index ea19d912..83a0bbd2 100644 --- a/internal/examples/server/opampsrv/opampsrv.go +++ b/internal/examples/server/opampsrv/opampsrv.go @@ -44,13 +44,13 @@ func NewServer(agents *data.Agents) *Server { func (srv *Server) Start() { settings := server.StartSettings{ Settings: server.Settings{ - Callbacks: server.CallbacksStruct{ - OnConnectingFunc: func(request *http.Request) types.ConnectionResponse { + Callbacks: types.Callbacks{ + OnConnecting: func(request *http.Request) types.ConnectionResponse { return types.ConnectionResponse{ Accept: true, - ConnectionCallbacks: server.ConnectionCallbacksStruct{ - OnMessageFunc: srv.onMessage, - OnConnectionCloseFunc: srv.onDisconnect, + ConnectionCallbacks: types.ConnectionCallbacks{ + OnMessage: srv.onMessage, + OnConnectionClose: srv.onDisconnect, }, } }, diff --git a/server/callbacks.go b/server/callbacks.go deleted file mode 100644 index c3e75a96..00000000 --- a/server/callbacks.go +++ /dev/null @@ -1,63 +0,0 @@ -package server - -import ( - "context" - "net/http" - - "github.com/open-telemetry/opamp-go/protobufs" - "github.com/open-telemetry/opamp-go/server/types" -) - -// CallbacksStruct is a struct that implements Callbacks interface and allows -// to override only the methods that are needed. If a method is not overridden then it will -// accept all connections. -type CallbacksStruct struct { - OnConnectingFunc func(request *http.Request) types.ConnectionResponse -} - -var _ types.Callbacks = (*CallbacksStruct)(nil) - -// OnConnecting implements Callbacks.interface. -func (c CallbacksStruct) OnConnecting(request *http.Request) types.ConnectionResponse { - if c.OnConnectingFunc != nil { - return c.OnConnectingFunc(request) - } - return types.ConnectionResponse{Accept: true} -} - -// ConnectionCallbacksStruct is a struct that implements ConnectionCallbacks interface and allows -// to override only the methods that are needed. -type ConnectionCallbacksStruct struct { - OnConnectedFunc func(ctx context.Context, conn types.Connection) - OnMessageFunc func(ctx context.Context, conn types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent - OnConnectionCloseFunc func(conn types.Connection) -} - -var _ types.ConnectionCallbacks = (*ConnectionCallbacksStruct)(nil) - -// OnConnected implements ConnectionCallbacks.OnConnected. -func (c ConnectionCallbacksStruct) OnConnected(ctx context.Context, conn types.Connection) { - if c.OnConnectedFunc != nil { - c.OnConnectedFunc(ctx, conn) - } -} - -// OnMessage implements ConnectionCallbacks.OnMessage. -// If OnMessageFunc is nil then it will send an empty response to the agent -func (c ConnectionCallbacksStruct) OnMessage(ctx context.Context, conn types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent { - if c.OnMessageFunc != nil { - return c.OnMessageFunc(ctx, conn, message) - } else { - // We will send an empty response since there is no user-defined callback to handle it. - return &protobufs.ServerToAgent{ - InstanceUid: message.InstanceUid, - } - } -} - -// OnConnectionClose implements ConnectionCallbacks.OnConnectionClose. -func (c ConnectionCallbacksStruct) OnConnectionClose(conn types.Connection) { - if c.OnConnectionCloseFunc != nil { - c.OnConnectionCloseFunc(conn) - } -} diff --git a/server/serverimpl.go b/server/serverimpl.go index 815c0528..dcdd044c 100644 --- a/server/serverimpl.go +++ b/server/serverimpl.go @@ -68,6 +68,7 @@ func New(logger types.Logger) *server { func (s *server) Attach(settings Settings) (HTTPHandlerFunc, ConnContext, error) { s.settings = settings + s.settings.Callbacks.SetDefaults() s.wsUpgrader = websocket.Upgrader{ EnableCompression: settings.EnableCompression, } @@ -169,26 +170,25 @@ func (s *server) Addr() net.Addr { func (s *server) httpHandler(w http.ResponseWriter, req *http.Request) { var connectionCallbacks serverTypes.ConnectionCallbacks - if s.settings.Callbacks != nil { - resp := s.settings.Callbacks.OnConnecting(req) - if !resp.Accept { - // HTTP connection is not accepted. Set the response headers. - for k, v := range resp.HTTPResponseHeader { - w.Header().Set(k, v) - } - // And write the response status code. - w.WriteHeader(resp.HTTPStatusCode) - return + resp := s.settings.Callbacks.OnConnecting(req) + if !resp.Accept { + // HTTP connection is not accepted. Set the response headers. + for k, v := range resp.HTTPResponseHeader { + w.Header().Set(k, v) } - // use connection-specific handler provided by ConnectionResponse - connectionCallbacks = resp.ConnectionCallbacks + // And write the response status code. + w.WriteHeader(resp.HTTPStatusCode) + return } + // use connection-specific handler provided by ConnectionResponse + connectionCallbacks = resp.ConnectionCallbacks + connectionCallbacks.SetDefaults() // HTTP connection is accepted. Check if it is a plain HTTP request. if req.Header.Get(headerContentType) == contentTypeProtobuf { // Yes, a plain HTTP request. - s.handlePlainHTTPRequest(req, w, connectionCallbacks) + s.handlePlainHTTPRequest(req, w, &connectionCallbacks) return } @@ -201,10 +201,10 @@ func (s *server) httpHandler(w http.ResponseWriter, req *http.Request) { // Return from this func to reduce memory usage. // Handle the connection on a separate goroutine. - go s.handleWSConnection(req.Context(), conn, connectionCallbacks) + go s.handleWSConnection(req.Context(), conn, &connectionCallbacks) } -func (s *server) handleWSConnection(reqCtx context.Context, wsConn *websocket.Conn, connectionCallbacks serverTypes.ConnectionCallbacks) { +func (s *server) handleWSConnection(reqCtx context.Context, wsConn *websocket.Conn, connectionCallbacks *serverTypes.ConnectionCallbacks) { agentConn := wsConnection{wsConn: wsConn, connMutex: &sync.Mutex{}} defer func() { @@ -216,14 +216,10 @@ func (s *server) handleWSConnection(reqCtx context.Context, wsConn *websocket.Co } }() - if connectionCallbacks != nil { - connectionCallbacks.OnConnectionClose(agentConn) - } + connectionCallbacks.OnConnectionClose(agentConn) }() - if connectionCallbacks != nil { - connectionCallbacks.OnConnected(reqCtx, agentConn) - } + connectionCallbacks.OnConnected(reqCtx, agentConn) sentCustomCapabilities := false @@ -254,21 +250,19 @@ func (s *server) handleWSConnection(reqCtx context.Context, wsConn *websocket.Co continue } - if connectionCallbacks != nil { - response := connectionCallbacks.OnMessage(msgContext, agentConn, &request) - if len(response.InstanceUid) == 0 { - response.InstanceUid = request.InstanceUid - } - if !sentCustomCapabilities { - response.CustomCapabilities = &protobufs.CustomCapabilities{ - Capabilities: s.settings.CustomCapabilities, - } - sentCustomCapabilities = true - } - err = agentConn.Send(msgContext, response) - if err != nil { - s.logger.Errorf(msgContext, "Cannot send message to WebSocket: %v", err) + response := connectionCallbacks.OnMessage(msgContext, agentConn, &request) + if len(response.InstanceUid) == 0 { + response.InstanceUid = request.InstanceUid + } + if !sentCustomCapabilities { + response.CustomCapabilities = &protobufs.CustomCapabilities{ + Capabilities: s.settings.CustomCapabilities, } + sentCustomCapabilities = true + } + err = agentConn.Send(msgContext, response) + if err != nil { + s.logger.Errorf(msgContext, "Cannot send message to WebSocket: %v", err) } } } @@ -310,7 +304,7 @@ func compressGzip(data []byte) ([]byte, error) { return buf.Bytes(), nil } -func (s *server) handlePlainHTTPRequest(req *http.Request, w http.ResponseWriter, connectionCallbacks serverTypes.ConnectionCallbacks) { +func (s *server) handlePlainHTTPRequest(req *http.Request, w http.ResponseWriter, connectionCallbacks *serverTypes.ConnectionCallbacks) { bodyBytes, err := s.readReqBody(req) if err != nil { s.logger.Debugf(req.Context(), "Cannot read HTTP body: %v", err) @@ -331,11 +325,6 @@ func (s *server) handlePlainHTTPRequest(req *http.Request, w http.ResponseWriter conn: connFromRequest(req), } - if connectionCallbacks == nil { - w.WriteHeader(http.StatusInternalServerError) - return - } - connectionCallbacks.OnConnected(req.Context(), agentConn) defer func() { diff --git a/server/serverimpl_test.go b/server/serverimpl_test.go index d7a9098e..ca7a5c33 100644 --- a/server/serverimpl_test.go +++ b/server/serverimpl_test.go @@ -127,8 +127,8 @@ func TestServerAddrWithZeroPort(t *testing.T) { } func TestServerStartRejectConnection(t *testing.T) { - callbacks := CallbacksStruct{ - OnConnectingFunc: func(request *http.Request) types.ConnectionResponse { + callbacks := types.Callbacks{ + OnConnecting: func(request *http.Request) types.ConnectionResponse { // Reject the incoming HTTP connection. return types.ConnectionResponse{ Accept: false, @@ -161,14 +161,14 @@ func TestServerStartAcceptConnection(t *testing.T) { connectedCalled := int32(0) connectionCloseCalled := int32(0) var srvConn types.Connection - callbacks := CallbacksStruct{ - OnConnectingFunc: func(request *http.Request) types.ConnectionResponse { - return types.ConnectionResponse{Accept: true, ConnectionCallbacks: ConnectionCallbacksStruct{ - OnConnectedFunc: func(ctx context.Context, conn types.Connection) { + callbacks := types.Callbacks{ + OnConnecting: func(request *http.Request) types.ConnectionResponse { + return types.ConnectionResponse{Accept: true, ConnectionCallbacks: types.ConnectionCallbacks{ + OnConnected: func(ctx context.Context, conn types.Connection) { srvConn = conn atomic.StoreInt32(&connectedCalled, 1) }, - OnConnectionCloseFunc: func(conn types.Connection) { + OnConnectionClose: func(conn types.Connection) { atomic.StoreInt32(&connectionCloseCalled, 1) assert.EqualValues(t, srvConn, conn) }, @@ -211,10 +211,10 @@ func TestDisconnectHttpConnection(t *testing.T) { func TestDisconnectWSConnection(t *testing.T) { connectionCloseCalled := int32(0) - callback := CallbacksStruct{ - OnConnectingFunc: func(request *http.Request) types.ConnectionResponse { - return types.ConnectionResponse{Accept: true, ConnectionCallbacks: ConnectionCallbacksStruct{ - OnConnectionCloseFunc: func(conn types.Connection) { + callback := types.Callbacks{ + OnConnecting: func(request *http.Request) types.ConnectionResponse { + return types.ConnectionResponse{Accept: true, ConnectionCallbacks: types.ConnectionCallbacks{ + OnConnectionClose: func(conn types.Connection) { atomic.StoreInt32(&connectionCloseCalled, 1) }, }} @@ -251,10 +251,10 @@ var testInstanceUid = []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6} func TestServerReceiveSendMessage(t *testing.T) { var rcvMsg atomic.Value - callbacks := CallbacksStruct{ - OnConnectingFunc: func(request *http.Request) types.ConnectionResponse { - return types.ConnectionResponse{Accept: true, ConnectionCallbacks: ConnectionCallbacksStruct{ - OnMessageFunc: func(ctx context.Context, conn types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent { + callbacks := types.Callbacks{ + OnConnecting: func(request *http.Request) types.ConnectionResponse { + return types.ConnectionResponse{Accept: true, ConnectionCallbacks: types.ConnectionCallbacks{ + OnMessage: func(ctx context.Context, conn types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent { // Remember received message. rcvMsg.Store(message) @@ -322,10 +322,10 @@ func TestServerReceiveSendMessageWithCompression(t *testing.T) { for _, withCompression := range tests { t.Run(fmt.Sprintf("%v", withCompression), func(t *testing.T) { var rcvMsg atomic.Value - callbacks := CallbacksStruct{ - OnConnectingFunc: func(request *http.Request) types.ConnectionResponse { - return types.ConnectionResponse{Accept: true, ConnectionCallbacks: ConnectionCallbacksStruct{ - OnMessageFunc: func(ctx context.Context, conn types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent { + callbacks := types.Callbacks{ + OnConnecting: func(request *http.Request) types.ConnectionResponse { + return types.ConnectionResponse{Accept: true, ConnectionCallbacks: types.ConnectionCallbacks{ + OnMessage: func(ctx context.Context, conn types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent { // Remember received message. rcvMsg.Store(message) @@ -420,13 +420,13 @@ func TestServerReceiveSendMessageWithCompression(t *testing.T) { func TestServerReceiveSendMessagePlainHTTP(t *testing.T) { var rcvMsg atomic.Value var onConnectedCalled, onCloseCalled int32 - callbacks := CallbacksStruct{ - OnConnectingFunc: func(request *http.Request) types.ConnectionResponse { - return types.ConnectionResponse{Accept: true, ConnectionCallbacks: ConnectionCallbacksStruct{ - OnConnectedFunc: func(ctx context.Context, conn types.Connection) { + callbacks := types.Callbacks{ + OnConnecting: func(request *http.Request) types.ConnectionResponse { + return types.ConnectionResponse{Accept: true, ConnectionCallbacks: types.ConnectionCallbacks{ + OnConnected: func(ctx context.Context, conn types.Connection) { atomic.StoreInt32(&onConnectedCalled, 1) }, - OnMessageFunc: func(ctx context.Context, conn types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent { + OnMessage: func(ctx context.Context, conn types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent { // Remember received message. rcvMsg.Store(message) @@ -437,7 +437,7 @@ func TestServerReceiveSendMessagePlainHTTP(t *testing.T) { } return &response }, - OnConnectionCloseFunc: func(conn types.Connection) { + OnConnectionClose: func(conn types.Connection) { atomic.StoreInt32(&onCloseCalled, 1) }, }} @@ -492,14 +492,14 @@ func TestServerAttachAcceptConnection(t *testing.T) { connectedCalled := int32(0) connectionCloseCalled := int32(0) var srvConn types.Connection - callbacks := CallbacksStruct{ - OnConnectingFunc: func(request *http.Request) types.ConnectionResponse { - return types.ConnectionResponse{Accept: true, ConnectionCallbacks: ConnectionCallbacksStruct{ - OnConnectedFunc: func(ctx context.Context, conn types.Connection) { + callbacks := types.Callbacks{ + OnConnecting: func(request *http.Request) types.ConnectionResponse { + return types.ConnectionResponse{Accept: true, ConnectionCallbacks: types.ConnectionCallbacks{ + OnConnected: func(ctx context.Context, conn types.Connection) { atomic.StoreInt32(&connectedCalled, 1) srvConn = conn }, - OnConnectionCloseFunc: func(conn types.Connection) { + OnConnectionClose: func(conn types.Connection) { atomic.StoreInt32(&connectionCloseCalled, 1) assert.EqualValues(t, srvConn, conn) }, @@ -542,14 +542,14 @@ func TestServerAttachSendMessagePlainHTTP(t *testing.T) { var rcvMsg atomic.Value var srvConn types.Connection - callbacks := CallbacksStruct{ - OnConnectingFunc: func(request *http.Request) types.ConnectionResponse { - return types.ConnectionResponse{Accept: true, ConnectionCallbacks: ConnectionCallbacksStruct{ - OnConnectedFunc: func(ctx context.Context, conn types.Connection) { + callbacks := types.Callbacks{ + OnConnecting: func(request *http.Request) types.ConnectionResponse { + return types.ConnectionResponse{Accept: true, ConnectionCallbacks: types.ConnectionCallbacks{ + OnConnected: func(ctx context.Context, conn types.Connection) { atomic.StoreInt32(&connectedCalled, 1) srvConn = conn }, - OnMessageFunc: func(ctx context.Context, conn types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent { + OnMessage: func(ctx context.Context, conn types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent { // Remember received message. rcvMsg.Store(message) @@ -560,7 +560,7 @@ func TestServerAttachSendMessagePlainHTTP(t *testing.T) { } return &response }, - OnConnectionCloseFunc: func(conn types.Connection) { + OnConnectionClose: func(conn types.Connection) { atomic.StoreInt32(&connectionCloseCalled, 1) assert.EqualValues(t, srvConn, conn) }, @@ -624,13 +624,13 @@ func TestServerHonoursClientRequestContentEncoding(t *testing.T) { hc := http.Client{} var rcvMsg atomic.Value var onConnectedCalled, onCloseCalled int32 - callbacks := CallbacksStruct{ - OnConnectingFunc: func(request *http.Request) types.ConnectionResponse { - return types.ConnectionResponse{Accept: true, ConnectionCallbacks: ConnectionCallbacksStruct{ - OnConnectedFunc: func(ctx context.Context, conn types.Connection) { + callbacks := types.Callbacks{ + OnConnecting: func(request *http.Request) types.ConnectionResponse { + return types.ConnectionResponse{Accept: true, ConnectionCallbacks: types.ConnectionCallbacks{ + OnConnected: func(ctx context.Context, conn types.Connection) { atomic.StoreInt32(&onConnectedCalled, 1) }, - OnMessageFunc: func(ctx context.Context, conn types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent { + OnMessage: func(ctx context.Context, conn types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent { // Remember received message. rcvMsg.Store(message) @@ -641,7 +641,7 @@ func TestServerHonoursClientRequestContentEncoding(t *testing.T) { } return &response }, - OnConnectionCloseFunc: func(conn types.Connection) { + OnConnectionClose: func(conn types.Connection) { atomic.StoreInt32(&onCloseCalled, 1) }, }} @@ -702,13 +702,13 @@ func TestServerHonoursAcceptEncoding(t *testing.T) { hc := http.Client{} var rcvMsg atomic.Value var onConnectedCalled, onCloseCalled int32 - callbacks := CallbacksStruct{ - OnConnectingFunc: func(request *http.Request) types.ConnectionResponse { - return types.ConnectionResponse{Accept: true, ConnectionCallbacks: ConnectionCallbacksStruct{ - OnConnectedFunc: func(ctx context.Context, conn types.Connection) { + callbacks := types.Callbacks{ + OnConnecting: func(request *http.Request) types.ConnectionResponse { + return types.ConnectionResponse{Accept: true, ConnectionCallbacks: types.ConnectionCallbacks{ + OnConnected: func(ctx context.Context, conn types.Connection) { atomic.StoreInt32(&onConnectedCalled, 1) }, - OnMessageFunc: func(ctx context.Context, conn types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent { + OnMessage: func(ctx context.Context, conn types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent { // Remember received message. rcvMsg.Store(message) @@ -719,7 +719,7 @@ func TestServerHonoursAcceptEncoding(t *testing.T) { } return &response }, - OnConnectionCloseFunc: func(conn types.Connection) { + OnConnectionClose: func(conn types.Connection) { atomic.StoreInt32(&onCloseCalled, 1) }, }} @@ -809,10 +809,10 @@ func TestDecodeMessage(t *testing.T) { func TestConnectionAllowsConcurrentWrites(t *testing.T) { ch := make(chan struct{}) srvConnVal := atomic.Value{} - callbacks := CallbacksStruct{ - OnConnectingFunc: func(request *http.Request) types.ConnectionResponse { - return types.ConnectionResponse{Accept: true, ConnectionCallbacks: ConnectionCallbacksStruct{ - OnConnectedFunc: func(ctx context.Context, conn types.Connection) { + callbacks := types.Callbacks{ + OnConnecting: func(request *http.Request) types.ConnectionResponse { + return types.ConnectionResponse{Accept: true, ConnectionCallbacks: types.ConnectionCallbacks{ + OnConnected: func(ctx context.Context, conn types.Connection) { srvConnVal.Store(conn) ch <- struct{}{} }, @@ -870,11 +870,11 @@ func TestServerCallsHTTPMiddlewareOverWebsocket(t *testing.T) { ) } - callbacks := CallbacksStruct{ - OnConnectingFunc: func(request *http.Request) types.ConnectionResponse { + callbacks := types.Callbacks{ + OnConnecting: func(request *http.Request) types.ConnectionResponse { return types.ConnectionResponse{ Accept: true, - ConnectionCallbacks: ConnectionCallbacksStruct{}, + ConnectionCallbacks: types.ConnectionCallbacks{}, } }, } @@ -914,11 +914,11 @@ func TestServerCallsHTTPMiddlewareOverHTTP(t *testing.T) { ) } - callbacks := CallbacksStruct{ - OnConnectingFunc: func(request *http.Request) types.ConnectionResponse { + callbacks := types.Callbacks{ + OnConnecting: func(request *http.Request) types.ConnectionResponse { return types.ConnectionResponse{ Accept: true, - ConnectionCallbacks: ConnectionCallbacksStruct{}, + ConnectionCallbacks: types.ConnectionCallbacks{}, } }, } @@ -965,10 +965,10 @@ func BenchmarkSendToClient(b *testing.B) { clientConnections := []*websocket.Conn{} serverConnections := []types.Connection{} srvConnectionsMutex := sync.Mutex{} - callbacks := CallbacksStruct{ - OnConnectingFunc: func(request *http.Request) types.ConnectionResponse { - return types.ConnectionResponse{Accept: true, ConnectionCallbacks: ConnectionCallbacksStruct{ - OnConnectedFunc: func(ctx context.Context, conn types.Connection) { + callbacks := types.Callbacks{ + OnConnecting: func(request *http.Request) types.ConnectionResponse { + return types.ConnectionResponse{Accept: true, ConnectionCallbacks: types.ConnectionCallbacks{ + OnConnected: func(ctx context.Context, conn types.Connection) { srvConnectionsMutex.Lock() serverConnections = append(serverConnections, conn) srvConnectionsMutex.Unlock() diff --git a/server/types/callbacks.go b/server/types/callbacks.go index 0546903d..b25e8181 100644 --- a/server/types/callbacks.go +++ b/server/types/callbacks.go @@ -15,12 +15,12 @@ type ConnectionResponse struct { ConnectionCallbacks ConnectionCallbacks } -type Callbacks interface { +type Callbacks struct { // OnConnecting is called when there is a new incoming connection. // The handler can examine the request and either accept or reject the connection. // To accept: // Return ConnectionResponse with Accept=true. ConnectionCallbacks MUST be set to an - // implementation of the ConnectionCallbacks interface to handle the connection. + // instance of the ConnectionCallbacks struct to handle the connection callbacks. // HTTPStatusCode and HTTPResponseHeader are ignored. // // To reject: @@ -28,28 +28,65 @@ type Callbacks interface { // non-zero value to indicate the rejection reason (typically 401, 429 or 503). // HTTPResponseHeader may be optionally set (e.g. "Retry-After: 30"). // ConnectionCallbacks is ignored. - OnConnecting(request *http.Request) ConnectionResponse + OnConnecting func(request *http.Request) ConnectionResponse } -// ConnectionCallbacks receives callbacks for a specific connection. An implementation of -// this interface MUST be set on the ConnectionResponse returned by the OnConnecting -// callback if Accept=true. The implementation can be shared by all connections or can be +func defaultOnConnecting(r *http.Request) ConnectionResponse { + return ConnectionResponse{Accept: true} +} + +func (c *Callbacks) SetDefaults() { + if c.OnConnecting == nil { + c.OnConnecting = defaultOnConnecting + } +} + +// ConnectionCallbacks specifies callbacks for a specific connection. An instance of +// this struct MUST be set on the ConnectionResponse returned by the OnConnecting +// callback if Accept=true. The instance can be shared by all connections or can be // unique for each connection. -type ConnectionCallbacks interface { +type ConnectionCallbacks struct { // The following callbacks will never be called concurrently for the same // connection. They may be called concurrently for different connections. // OnConnected is called when an incoming OpAMP connection is successfully // established after OnConnecting() returns. - OnConnected(ctx context.Context, conn Connection) + OnConnected func(ctx context.Context, conn Connection) // OnMessage is called when a message is received from the connection. Can happen // only after OnConnected(). Must return a ServerToAgent message that will be sent // as a response to the Agent. // For plain HTTP requests once OnMessage returns and the response is sent // to the Agent the OnConnectionClose message will be called immediately. - OnMessage(ctx context.Context, conn Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent + OnMessage func(ctx context.Context, conn Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent // OnConnectionClose is called when the OpAMP connection is closed. - OnConnectionClose(conn Connection) + OnConnectionClose func(conn Connection) +} + +func defaultOnConnected(ctx context.Context, conn Connection) {} + +func defaultOnMessage( + ctx context.Context, conn Connection, message *protobufs.AgentToServer, +) *protobufs.ServerToAgent { + // We will send an empty response since there is no user-defined callback to handle it. + return &protobufs.ServerToAgent{ + InstanceUid: message.InstanceUid, + } +} + +func defaultOnConnectionClose(conn Connection) {} + +func (c *ConnectionCallbacks) SetDefaults() { + if c.OnConnected == nil { + c.OnConnected = defaultOnConnected + } + + if c.OnMessage == nil { + c.OnMessage = defaultOnMessage + } + + if c.OnConnectionClose == nil { + c.OnConnectionClose = defaultOnConnectionClose + } }