diff --git a/internal/examples/go.mod b/internal/examples/go.mod index cb4cf630..ee9dd117 100644 --- a/internal/examples/go.mod +++ b/internal/examples/go.mod @@ -9,6 +9,7 @@ require ( github.com/open-telemetry/opamp-go v0.1.0 github.com/shirou/gopsutil v3.21.11+incompatible github.com/stretchr/testify v1.8.4 + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 go.opentelemetry.io/otel v1.24.0 go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp v1.24.0 go.opentelemetry.io/otel/metric v1.24.0 @@ -19,12 +20,13 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fsnotify/fsnotify v1.4.9 // indirect github.com/go-logr/logr v1.4.1 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-ole/go-ole v1.2.6 // indirect - github.com/gorilla/websocket v1.5.1 // indirect github.com/golang/protobuf v1.5.3 // indirect + github.com/gorilla/websocket v1.5.1 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.1 // indirect github.com/mitchellh/copystructure v1.2.0 // indirect github.com/mitchellh/mapstructure v1.4.1 // indirect diff --git a/internal/examples/go.sum b/internal/examples/go.sum index 237c4196..436276a8 100644 --- a/internal/examples/go.sum +++ b/internal/examples/go.sum @@ -21,6 +21,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M= +github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= +github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/go-ldap/ldap v3.0.2+incompatible/go.mod h1:qfd9rJvER9Q0/D/Sqn1DfHRoBp40uXYvFoEVrNEPqRc= @@ -127,6 +129,8 @@ github.com/tklauser/numcpus v0.3.0 h1:ILuRUQBtssgnxw0XXIjKUC56fgnOrFoQQ/4+DeU2bi github.com/tklauser/numcpus v0.3.0/go.mod h1:yFGUr7TUHQRAhyqBcEg0Ge34zDBAsIvJJcyE6boqnA8= github.com/yusufpapurcu/wmi v1.2.2 h1:KBNDSne4vP5mbSWnJbO+51IMOXJB67QiYCSBrubbPRg= github.com/yusufpapurcu/wmi v1.2.2/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 h1:jq9TW8u3so/bN+JPT166wjOI6/vQPF6Xe7nMNIltagk= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0/go.mod h1:p8pYQP+m5XfbZm9fxtSKAbM6oIllS7s2AfxrChvc7iw= go.opentelemetry.io/otel v1.24.0 h1:0LAOdjNmQeSTzGBzduGe/rU4tZhMwL5rWgtp9Ku5Jfo= go.opentelemetry.io/otel v1.24.0/go.mod h1:W7b9Ozg4nkF5tWI5zsXkaKKDjdVjpD4oAt9Qi/MArHo= go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp v1.24.0 h1:mM8nKi6/iFQ0iqst80wDHU2ge198Ye/TfN0WBS5U24Y= diff --git a/internal/examples/server/opampsrv/opampsrv.go b/internal/examples/server/opampsrv/opampsrv.go index bc652738..84bb2018 100644 --- a/internal/examples/server/opampsrv/opampsrv.go +++ b/internal/examples/server/opampsrv/opampsrv.go @@ -6,6 +6,8 @@ import ( "net/http" "os" + "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" + "github.com/open-telemetry/opamp-go/internal" "github.com/open-telemetry/opamp-go/internal/examples/server/data" "github.com/open-telemetry/opamp-go/protobufs" @@ -54,6 +56,7 @@ func (srv *Server) Start() { }, }, ListenEndpoint: "127.0.0.1:4320", + HTTPMiddleware: otelhttp.NewMiddleware("/v1/opamp"), } tlsConfig, err := internal.CreateServerTLSConfig( "../../certs/certs/ca.cert.pem", diff --git a/server/server.go b/server/server.go index 066edf42..46acff9f 100644 --- a/server/server.go +++ b/server/server.go @@ -42,6 +42,11 @@ type StartSettings struct { // Server's TLS configuration. TLSConfig *tls.Config + + // HTTPMiddleware specifies middleware for HTTP messages received by the server. + // Note that the function will be called once for websockets upon connecting and will + // be called for every HTTP request. This function is optional to set. + HTTPMiddleware func(handler http.Handler) http.Handler } type HTTPHandlerFunc func(http.ResponseWriter, *http.Request) diff --git a/server/serverimpl.go b/server/serverimpl.go index 10b159c3..2f5965a0 100644 --- a/server/serverimpl.go +++ b/server/serverimpl.go @@ -47,6 +47,16 @@ type server struct { var _ OpAMPServer = (*server)(nil) +// innerHTTPHandler implements the http.Handler interface so it can be used by functions +// that require the type (like Middleware) without exposing ServeHTTP directly on server. +type innerHTTPHander struct { + httpHandlerFunc http.HandlerFunc +} + +func (i innerHTTPHander) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + i.httpHandlerFunc(writer, request) +} + // New creates a new OpAMP Server. func New(logger types.Logger) *server { if logger == nil { @@ -82,7 +92,13 @@ func (s *server) Start(settings StartSettings) error { path = defaultOpAMPPath } - mux.HandleFunc(path, s.httpHandler) + handler := innerHTTPHander{s.httpHandler} + + if settings.HTTPMiddleware != nil { + mux.Handle(path, settings.HTTPMiddleware(handler)) + } else { + mux.Handle(path, handler) + } hs := &http.Server{ Handler: mux, diff --git a/server/serverimpl_test.go b/server/serverimpl_test.go index 071a99e1..f17263f7 100644 --- a/server/serverimpl_test.go +++ b/server/serverimpl_test.go @@ -57,6 +57,33 @@ func TestServerStartStop(t *testing.T) { assert.NoError(t, err) } +func TestServerStartStopWithMiddleware(t *testing.T) { + var addedMiddleware atomic.Bool + assert.False(t, addedMiddleware.Load()) + + testHTTPMiddleware := func(handler http.Handler) http.Handler { + addedMiddleware.Store(true) + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + handler.ServeHTTP(w, r) + }, + ) + } + + startSettings := &StartSettings{ + HTTPMiddleware: testHTTPMiddleware, + } + + srv := startServer(t, startSettings) + assert.True(t, addedMiddleware.Load()) + + err := srv.Start(*startSettings) + assert.ErrorIs(t, err, errAlreadyStarted) + + err = srv.Stop(context.Background()) + assert.NoError(t, err) +} + func TestServerAddrWithNonZeroPort(t *testing.T) { srv := New(&sharedinternal.NopLogger{}) require.NotNil(t, srv) @@ -830,6 +857,109 @@ func TestConnectionAllowsConcurrentWrites(t *testing.T) { } } +func TestServerCallsHTTPMiddlewareOverWebsocket(t *testing.T) { + middlewareCalled := int32(0) + + testHTTPMiddleware := func(handler http.Handler) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&middlewareCalled, 1) + handler.ServeHTTP(w, r) + }, + ) + } + + callbacks := CallbacksStruct{ + OnConnectingFunc: func(request *http.Request) types.ConnectionResponse { + return types.ConnectionResponse{ + Accept: true, + ConnectionCallbacks: ConnectionCallbacksStruct{}, + } + }, + } + + // Start a Server + settings := &StartSettings{ + HTTPMiddleware: testHTTPMiddleware, + Settings: Settings{Callbacks: callbacks}, + } + srv := startServer(t, settings) + defer func() { + err := srv.Stop(context.Background()) + assert.NoError(t, err) + }() + + // Connect to the server, ensuring successful connection + conn, resp, err := dialClient(settings) + assert.NoError(t, err) + assert.NotNil(t, conn) + require.NotNil(t, resp) + assert.EqualValues(t, 101, resp.StatusCode) + + // Verify middleware was called once for the websocket connection + eventually(t, func() bool { return atomic.LoadInt32(&middlewareCalled) == int32(1) }) + assert.Equal(t, int32(1), atomic.LoadInt32(&middlewareCalled)) +} + +func TestServerCallsHTTPMiddlewareOverHTTP(t *testing.T) { + middlewareCalled := int32(0) + + testHTTPMiddleware := func(handler http.Handler) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&middlewareCalled, 1) + handler.ServeHTTP(w, r) + }, + ) + } + + callbacks := CallbacksStruct{ + OnConnectingFunc: func(request *http.Request) types.ConnectionResponse { + return types.ConnectionResponse{ + Accept: true, + ConnectionCallbacks: ConnectionCallbacksStruct{}, + } + }, + } + + // Start a Server + settings := &StartSettings{ + HTTPMiddleware: testHTTPMiddleware, + Settings: Settings{Callbacks: callbacks}, + } + srv := startServer(t, settings) + defer func() { + err := srv.Stop(context.Background()) + assert.NoError(t, err) + }() + + // Send an AgentToServer message to the Server + sendMsg1 := protobufs.AgentToServer{InstanceUid: "01BX5ZZKBKACTAV9WEVGEMMVS1"} + serializedProtoBytes1, err := proto.Marshal(&sendMsg1) + require.NoError(t, err) + _, err = http.Post( + "http://"+settings.ListenEndpoint+settings.ListenPath, + contentTypeProtobuf, + bytes.NewReader(serializedProtoBytes1), + ) + require.NoError(t, err) + + // Send another AgentToServer message to the Server + sendMsg2 := protobufs.AgentToServer{InstanceUid: "01BX5ZZKBKACTAV9WEVGEMMVRZ"} + serializedProtoBytes2, err := proto.Marshal(&sendMsg2) + require.NoError(t, err) + _, err = http.Post( + "http://"+settings.ListenEndpoint+settings.ListenPath, + contentTypeProtobuf, + bytes.NewReader(serializedProtoBytes2), + ) + require.NoError(t, err) + + // Verify middleware was triggered for each HTTP call + eventually(t, func() bool { return atomic.LoadInt32(&middlewareCalled) == int32(2) }) + assert.Equal(t, int32(2), atomic.LoadInt32(&middlewareCalled)) +} + func BenchmarkSendToClient(b *testing.B) { clientConnections := []*websocket.Conn{} serverConnections := []types.Connection{}