From 6dc3743435698b3f7a2e7043719a302d0c53d21b Mon Sep 17 00:00:00 2001 From: Owen Cabalceta Date: Mon, 18 Mar 2024 06:14:08 -0400 Subject: [PATCH] feat: connect to xmidt cluster feat: connect to xmidt cluster --- .gitignore | 5 +- MAINTAINERS.md | 2 +- cmd/xmidt-agent/config.go | 37 +++++++ cmd/xmidt-agent/main.go | 110 ++++++++++++++++---- cmd/xmidt-agent/ws.go | 173 +++++++++++++++++++++++++++++++ cmd/xmidt-agent/xmidt_agent.yaml | 27 +++++ go.mod | 11 +- go.sum | 24 ++--- internal/websocket/options.go | 74 +++++++++++++ internal/websocket/ws.go | 43 ++++++-- 10 files changed, 456 insertions(+), 50 deletions(-) create mode 100644 cmd/xmidt-agent/ws.go create mode 100644 cmd/xmidt-agent/xmidt_agent.yaml diff --git a/.gitignore b/.gitignore index aa16cb9..ca0d682 100644 --- a/.gitignore +++ b/.gitignore @@ -15,7 +15,10 @@ *.out # VS Code directories -.vscode +*.code-workspace +.vscode/* +.dev/* +__debug_bin* # Dependency directories (remove the comment below to include it) # vendor/ diff --git a/MAINTAINERS.md b/MAINTAINERS.md index ad4e29e..c9bea7d 100644 --- a/MAINTAINERS.md +++ b/MAINTAINERS.md @@ -1,6 +1,6 @@ Maintainers of this repository: * Weston Schmidt @schmidtw -* Joel Unzain @joe94 +* Owen Cabalceta @denopink * John Bass @johnabass * Nick Harter @njharter diff --git a/cmd/xmidt-agent/config.go b/cmd/xmidt-agent/config.go index 42f370b..fd84f2f 100644 --- a/cmd/xmidt-agent/config.go +++ b/cmd/xmidt-agent/config.go @@ -11,6 +11,7 @@ import ( "github.com/goschtalt/goschtalt" "github.com/xmidt-org/arrange/arrangehttp" + "github.com/xmidt-org/retry" "github.com/xmidt-org/sallust" "github.com/xmidt-org/wrp-go/v3" "gopkg.in/dealancer/validate.v2" @@ -18,6 +19,7 @@ import ( // Config is the configuration for the xmidt-agent. type Config struct { + Client Client Identity Identity OperationalState OperationalState XmidtCredentials XmidtCredentials @@ -26,6 +28,41 @@ type Config struct { Storage Storage } +type Client struct { + // FetchURL is the url used to fetch the WS url. + FetchURL string + // (optional) FetchURLTimeout is the timeout for the fetching the WS url. If this is not set, the default is 30 seconds. + FetchURLTimeout time.Duration + // (optional) PingInterval is the ping interval allowed for the WS connection. + PingInterval time.Duration + // (optional) PingTimeout is the ping timeout for the WS connection. + PingTimeout time.Duration + // (optional) ConnectTimeout is the connect timeout for the WS connection. + ConnectTimeout time.Duration + // (optional) KeepAliveInterval is the keep alive interval for the WS connection. + KeepAliveInterval time.Duration + // (optional) IdleConnTimeout is the idle connection timeout for the WS connection. + IdleConnTimeout time.Duration + // (optional) TLSHandshakeTimeout is the TLS handshake timeout for the WS connection. + TLSHandshakeTimeout time.Duration + // (optional) ExpectContinueTimeout is the expect continue timeout for the WS connection. + ExpectContinueTimeout time.Duration + // (optional) MaxMessageBytes is the largest allowable message to send or receive. + MaxMessageBytes int64 + // (optional) DisableV4 determines whether or not to allow IPv4 for the WS connection. + // If this is not set, the default is false (IPv4 is enabled). + // Either V4 or V6 can be disabled, but not both. + DisableV4 bool + // (optional) DisableV6 determines whether or not to allow IPv6 for the WS connection. + // If this is not set, the default is false (IPv6 is enabled). + // Either V4 or V6 can be disabled, but not both. + DisableV6 bool + // (optional) RetryPolicy sets the retry policy factory used for delaying between retry attempts for reconnection. + RetryPolicy retry.Config + // (optional) Once sets whether or not to only attempt to connect once. + Once bool +} + // Identity contains the information that identifies the device. type Identity struct { // DeviceID is the unique identifier for the device. Generally this is a diff --git a/cmd/xmidt-agent/main.go b/cmd/xmidt-agent/main.go index 3230a93..79dd0e6 100644 --- a/cmd/xmidt-agent/main.go +++ b/cmd/xmidt-agent/main.go @@ -7,6 +7,7 @@ import ( "context" "fmt" "os" + "runtime/debug" "github.com/alecthomas/kong" "github.com/goschtalt/goschtalt" @@ -14,8 +15,11 @@ import ( _ "github.com/goschtalt/yaml-decoder" _ "github.com/goschtalt/yaml-encoder" "github.com/xmidt-org/sallust" + "github.com/xmidt-org/wrp-go/v3" "github.com/xmidt-org/xmidt-agent/internal/credentials" "github.com/xmidt-org/xmidt-agent/internal/jwtxt" + "github.com/xmidt-org/xmidt-agent/internal/websocket" + "github.com/xmidt-org/xmidt-agent/internal/websocket/event" "go.uber.org/fx" "go.uber.org/fx/fxevent" @@ -43,6 +47,15 @@ type CLI struct { Files []string `optional:"" short:"f" help:"Specific configuration files or directories."` } +type LifeCycleIn struct { + fx.In + Logger *zap.Logger + LC fx.Lifecycle + Shutdowner fx.Shutdowner + WS *websocket.Websocket + CancelList []event.CancelFunc +} + // xmidtAgent is the main entry point for the program. It is responsible for // setting up the dependency injection framework and returning the app object. func xmidtAgent(args []string) (*fx.App, error) { @@ -72,6 +85,10 @@ func xmidtAgent(args []string) (*fx.App, error) { provideConfig, provideCredentials, provideInstructions, + provideWS, + func(id Identity) wrp.DeviceID { + return id.DeviceID + }, goschtalt.UnmarshalFunc[sallust.Config]("logger", goschtalt.Optional()), goschtalt.UnmarshalFunc[Identity]("identity"), @@ -79,6 +96,7 @@ func xmidtAgent(args []string) (*fx.App, error) { goschtalt.UnmarshalFunc[XmidtCredentials]("xmidt_credentials"), goschtalt.UnmarshalFunc[XmidtService]("xmidt_service"), goschtalt.UnmarshalFunc[Storage]("storage"), + goschtalt.UnmarshalFunc[Client]("client"), ), fsProvide(), @@ -96,6 +114,7 @@ func xmidtAgent(args []string) (*fx.App, error) { fmt.Println(s) } }, + lifeCycle, ), ) @@ -170,29 +189,76 @@ func provideCLIWithOpts(args cliArgs, testOpts bool) (*CLI, error) { return &cli, nil } +type LoggerIn struct { + fx.In + CLI *CLI + Cfg sallust.Config +} + // Create the logger and configure it based on if the program is in // debug mode or normal mode. -func provideLogger(cli *CLI, cfg sallust.Config) (*zap.Logger, error) { - if cli.Dev { - cfg.Level = "DEBUG" - cfg.Development = true - cfg.Encoding = "console" - cfg.EncoderConfig = sallust.EncoderConfig{ - TimeKey: "T", - LevelKey: "L", - NameKey: "N", - CallerKey: "C", - FunctionKey: zapcore.OmitKey, - MessageKey: "M", - StacktraceKey: "S", - LineEnding: zapcore.DefaultLineEnding, - EncodeLevel: "capitalColor", - EncodeTime: "RFC3339", - EncodeDuration: "string", - EncodeCaller: "short", - } - cfg.OutputPaths = []string{"stderr"} - cfg.ErrorOutputPaths = []string{"stderr"} +func provideLogger(in LoggerIn) (*zap.Logger, error) { + in.Cfg.EncoderConfig = sallust.EncoderConfig{ + TimeKey: "T", + LevelKey: "L", + NameKey: "N", + CallerKey: "C", + FunctionKey: zapcore.OmitKey, + MessageKey: "M", + StacktraceKey: "S", + LineEnding: zapcore.DefaultLineEnding, + EncodeLevel: "capitalColor", + EncodeTime: "RFC3339", + EncodeDuration: "string", + EncodeCaller: "short", + } + + if in.CLI.Dev { + in.Cfg.Level = "DEBUG" + in.Cfg.Development = true + in.Cfg.Encoding = "console" + in.Cfg.OutputPaths = append(in.Cfg.OutputPaths, "stderr") + in.Cfg.ErrorOutputPaths = append(in.Cfg.ErrorOutputPaths, "stderr") } - return cfg.Build() + + return in.Cfg.Build() +} + +func lifeCycle(in LifeCycleIn) { + logger := in.Logger.With(zap.String("component", "fx_lifecycle")) + in.LC.Append( + fx.Hook{ + OnStart: func(ctx context.Context) error { + defer func() { + if r := recover(); nil != r { + logger.Error("stacktrace from panic", zap.String("stacktrace", string(debug.Stack())), zap.Any("panic", r)) + } + }() + + logger.Info("starting ws") + in.WS.Start() + + return nil + }, + OnStop: func(ctx context.Context) error { + defer func() { + if r := recover(); nil != r { + logger.Error("stacktrace from panic", zap.String("stacktrace", string(debug.Stack())), zap.Any("panic", r)) + } + + if err := in.Shutdowner.Shutdown(); err != nil { + logger.Error("encountered error trying to shutdown app: ", zap.Error(err)) + } + }() + + logger.Info("stopping ws listeners") + in.WS.Stop() + for _, c := range in.CancelList { + c() + } + + return nil + }, + }, + ) } diff --git a/cmd/xmidt-agent/ws.go b/cmd/xmidt-agent/ws.go new file mode 100644 index 0000000..b062cac --- /dev/null +++ b/cmd/xmidt-agent/ws.go @@ -0,0 +1,173 @@ +package main + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/url" + "strings" + "time" + + "github.com/xmidt-org/wrp-go/v3" + "github.com/xmidt-org/wrp-go/v3/wrphttp" + "github.com/xmidt-org/xmidt-agent/internal/websocket" + "github.com/xmidt-org/xmidt-agent/internal/websocket/event" + "go.uber.org/fx" + "go.uber.org/zap" +) + +var ( + ErrClientConfig = errors.New("client configuration error") +) + +type wsIn struct { + fx.In + DeviceID wrp.DeviceID + Logger *zap.Logger + CLI *CLI + Client Client +} + +type wsOut struct { + fx.Out + WS *websocket.Websocket + CancelList []event.CancelFunc +} + +func provideWS(in wsIn) (wsOut, error) { + opts := []websocket.Option{ + websocket.DeviceID(in.DeviceID), + websocket.FetchURL(func(ctx context.Context) (string, error) { + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + req, err := http.NewRequestWithContext(ctx, "GET", in.Client.FetchURL, nil) + if err != nil { + return "", fmt.Errorf("failed to fetch ws url: %s", err) + } + + req.Header.Set(wrphttp.DestinationHeader, string(in.DeviceID.Bytes())) + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("failed to fetch ws url: %s: %s", getDoErrReason(err), err) + } + + defer resp.Body.Close() + if resp.StatusCode != http.StatusTemporaryRedirect { + respBody, _ := io.ReadAll(resp.Body) + return "", fmt.Errorf("failed to fetch ws url: unexpected status code (expected: %d): %d: %s", http.StatusTemporaryRedirect, resp.StatusCode, respBody) + } + + url, err := resp.Location() + if err != nil { + return "", fmt.Errorf("failed to fetch ws url: %s", err) + } + + return url.String(), nil + }), + websocket.RetryPolicy(in.Client.RetryPolicy), + websocket.Logger(in.Logger), + websocket.NowFunc(time.Now), + websocket.FetchURLTimeout(in.Client.FetchURLTimeout), + websocket.PingInterval(in.Client.PingInterval), + websocket.PingTimeout(in.Client.PingTimeout), + websocket.ConnectTimeout(in.Client.ConnectTimeout), + websocket.KeepAliveInterval(in.Client.KeepAliveInterval), + websocket.IdleConnTimeout(in.Client.IdleConnTimeout), + websocket.TLSHandshakeTimeout(in.Client.TLSHandshakeTimeout), + websocket.ExpectContinueTimeout(in.Client.ExpectContinueTimeout), + websocket.MaxMessageBytes(in.Client.MaxMessageBytes), + websocket.WithIPv6(!in.Client.DisableV6), + websocket.WithIPv4(!in.Client.DisableV4), + websocket.Once(in.Client.Once), + } + + var msg, con, discon event.CancelFunc + if in.CLI.Dev { + opts = append(opts, + websocket.AddMessageListener( + event.MsgListenerFunc( + func(m wrp.Message) { + in.Logger.Info("message listener", zap.Any("msg", m)) + }), &msg), + websocket.AddConnectListener( + event.ConnectListenerFunc( + func(e event.Connect) { + in.Logger.Info("connect listener", zap.Any("event", e)) + }), &con), + websocket.AddDisconnectListener( + event.DisconnectListenerFunc( + func(e event.Disconnect) { + in.Logger.Info("disconnect listener", zap.Any("event", e)) + }), &discon), + ) + } + + if in.Client.FetchURL == "" { + return wsOut{}, fmt.Errorf("%w: client FetchURL can't be empty", ErrClientConfig) + } + + ws, err := websocket.New(opts...) + return wsOut{ + WS: ws, + CancelList: []event.CancelFunc{msg, con, discon}, + }, err +} + +const ( + genericDoReason = "do_error" + deadlineExceededReason = "context_deadline_exceeded" + contextCanceledReason = "context_canceled" + addressErrReason = "address_error" + parseAddrErrReason = "parse_address_error" + invalidAddrReason = "invalid_address" + dnsErrReason = "dns_error" + hostNotFoundReason = "host_not_found" + connClosedReason = "connection_closed" + opErrReason = "op_error" + networkErrReason = "unknown_network_err" + connectionUnexpectedlyClosedEOFReason = "connection_unexpectedly_closed_eof" + noErrReason = "no_err" +) + +func getDoErrReason(err error) string { + var d *net.DNSError + if err == nil { + return noErrReason + } else if errors.Is(err, context.DeadlineExceeded) { + return deadlineExceededReason + } else if errors.Is(err, context.Canceled) { + return contextCanceledReason + } else if errors.Is(err, &net.AddrError{}) { + return addressErrReason + } else if errors.Is(err, &net.ParseError{}) { + return parseAddrErrReason + } else if errors.Is(err, net.InvalidAddrError("")) { + return invalidAddrReason + } else if errors.As(err, &d) { + if d.IsNotFound { + return hostNotFoundReason + } + return dnsErrReason + } else if errors.Is(err, net.ErrClosed) { + return connClosedReason + } else if errors.Is(err, &net.OpError{}) { + return opErrReason + } else if errors.Is(err, net.UnknownNetworkError("")) { + return networkErrReason + } + + // nolint: errorlint + if err, ok := err.(*url.Error); ok { + if strings.TrimSpace(strings.ToLower(err.Unwrap().Error())) == "eof" { + return connectionUnexpectedlyClosedEOFReason + } + } + + return genericDoReason +} diff --git a/cmd/xmidt-agent/xmidt_agent.yaml b/cmd/xmidt-agent/xmidt_agent.yaml new file mode 100644 index 0000000..b6a6e5b --- /dev/null +++ b/cmd/xmidt-agent/xmidt_agent.yaml @@ -0,0 +1,27 @@ +client: + fetch_url: https://localhost:8080/api/v2/device +xmidt_credentials: + url: https://localhost:8080/issue + file_name: crt.pem + file_permissions: 0777 + http_client: + tls: + insecure_skip_verify: true + certificates: + - certificate_file: crt.pem + key_file: key.pem + min_version: 771 # 0x0303, the TLS 1.2 version uint16 +identity: + device_id: mac:00deadbeef00 + serial_number: 1800deadbeef + hardware_model: fooModel + hardware_manufacturer: barManufacturer + firmware_version: v0.0.1 + partner_id: foobar +operational_state: + last_reboot_reason: sleepy + boot_time: "2024-02-28T01:04:27Z" +# Optional +# storage: +# temporary: ~/local-rdk-testing/temporary +# durable: ~/local-rdk-testing/durable \ No newline at end of file diff --git a/go.mod b/go.mod index 0dcf698..83b93d8 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/xmidt-org/xmidt-agent -go 1.21.0 +go 1.21.8 require ( github.com/alecthomas/kong v0.9.0 @@ -13,7 +13,7 @@ require ( github.com/stretchr/testify v1.9.0 github.com/ugorji/go/codec v1.2.12 github.com/xmidt-org/arrange v0.5.0 - github.com/xmidt-org/eventor v0.0.0-20230910205925-8ff168bd12ed + github.com/xmidt-org/eventor v0.0.0-20240304051151-5d41136e8fdd github.com/xmidt-org/retry v0.0.3 github.com/xmidt-org/sallust v0.2.2 github.com/xmidt-org/wrp-go/v3 v3.5.1 @@ -25,8 +25,11 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/go-kit/kit v0.13.0 // indirect + github.com/go-kit/log v0.2.1 // indirect + github.com/go-logfmt/logfmt v0.6.0 // indirect github.com/goschtalt/approx v1.0.0 // indirect - github.com/leodido/go-urn v1.2.4 // indirect + github.com/leodido/go-urn v1.4.0 // indirect github.com/miekg/dns v1.1.57 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect @@ -36,7 +39,7 @@ require ( go.uber.org/multierr v1.11.0 // indirect golang.org/x/mod v0.14.0 // indirect golang.org/x/net v0.18.0 // indirect - golang.org/x/sys v0.16.0 // indirect + golang.org/x/sys v0.18.0 // indirect golang.org/x/tools v0.15.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index bd0de90..3061c4b 100644 --- a/go.sum +++ b/go.sum @@ -9,6 +9,12 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/foxcpp/go-mockdns v1.1.0 h1:jI0rD8M0wuYAxL7r/ynTrCQQq0BVqfB99Vgk7DlmewI= github.com/foxcpp/go-mockdns v1.1.0/go.mod h1:IhLeSFGed3mJIAXPH2aiRQB+kqz7oqu8ld2qVbOu7Wk= +github.com/go-kit/kit v0.13.0 h1:OoneCcHKHQ03LfBpoQCUfCluwd2Vt3ohz+kvbJneZAU= +github.com/go-kit/kit v0.13.0/go.mod h1:phqEHMMUbyrCFCTgH48JueqrM3md2HcAZ8N3XE4FKDg= +github.com/go-kit/log v0.2.1 h1:MRVx0/zhvdseW+Gza6N9rVzU/IVzaeE1SFI4raAhmBU= +github.com/go-kit/log v0.2.1/go.mod h1:NwTd00d/i8cPZ3xOwwiv2PO5MOcx78fFErGNcVmBjv0= +github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi4= +github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs= github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= @@ -30,8 +36,8 @@ github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3x github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw= -github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= -github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= +github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= +github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/miekg/dns v1.1.57 h1:Jzi7ApEIzwEPLHWRcafCN9LZSBbqQpxjt/wpgvg7wcM= github.com/miekg/dns v1.1.57/go.mod h1:uqRjCRUuEAA6qsOiJvDd+CFo/vW+y5WR6SNmHE55hZk= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= @@ -43,22 +49,17 @@ github.com/psanford/memfs v0.0.0-20210214183328-a001468d78ef/go.mod h1:tcaRap0jS github.com/rogpeppe/go-internal v1.8.1 h1:geMPLpDpQOgVyCg5z5GoRwLHepNdb71NXb67XFkP+Eg= github.com/rogpeppe/go-internal v1.8.1/go.mod h1:JeRgkft04UBgHMgCIwADu4Pn6Mtm5d4nPKWu0nJ5d+o= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= github.com/xmidt-org/arrange v0.5.0 h1:ajkVHkr7dXnfCYm/6eafWoOab+6A3b2jEHQO0IdIIb0= github.com/xmidt-org/arrange v0.5.0/go.mod h1:PoZB9lR49ma0osydQbaWpNeA3XPoLkjP5RYUoOw8wZU= -github.com/xmidt-org/eventor v0.0.0-20230910205925-8ff168bd12ed h1:KpcgFuumKrt/824H3gtmNI/IvgjsBo6rnlSnwXlFu60= -github.com/xmidt-org/eventor v0.0.0-20230910205925-8ff168bd12ed/go.mod h1:X9Og+8y1Llz7N8F20UmjZUNgrxHubMVfBcroJ5SPtIY= +github.com/xmidt-org/eventor v0.0.0-20240304051151-5d41136e8fdd h1:jHUwcgGICAYTlvMo3a1yaIiSgu3sypj3POyRNv2g35o= +github.com/xmidt-org/eventor v0.0.0-20240304051151-5d41136e8fdd/go.mod h1:NpaRwPEiiaB5oEdFI41o6Lf4iQHAVwCdtwKb3z7R8mY= github.com/xmidt-org/httpaux v0.4.0 h1:cAL/MzIBpSsv4xZZeq/Eu1J5M3vfNe49xr41mP3COKU= github.com/xmidt-org/httpaux v0.4.0/go.mod h1:UypqZwuZV1nn8D6+K1JDb+im9IZrLNg/2oO/Bgiybxc= github.com/xmidt-org/retry v0.0.3 h1:wvmBnEEn1OKwSZaQtr1RZ2Vey8JIvP72mGTgR+3wPiM= @@ -114,8 +115,8 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU= -golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= +golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= @@ -147,7 +148,6 @@ gopkg.in/dealancer/validate.v2 v2.1.0 h1:XY95SZhVH1rBe8uwtnQEsOO79rv8GPwK+P3VWhQ gopkg.in/dealancer/validate.v2 v2.1.0/go.mod h1:EipWMj8hVO2/dPXVlYRe9yKcgVd5OttpQDiM1/wZ0DE= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= nhooyr.io/websocket v1.8.10 h1:mv4p+MnGrLDcPlBoWsvPP7XCzTYMXP9F9eIGoKbgx7Q= diff --git a/internal/websocket/options.go b/internal/websocket/options.go index 5f0982e..3f83541 100644 --- a/internal/websocket/options.go +++ b/internal/websocket/options.go @@ -10,8 +10,22 @@ import ( "time" "github.com/xmidt-org/retry" + "github.com/xmidt-org/sallust" "github.com/xmidt-org/wrp-go/v3" "github.com/xmidt-org/xmidt-agent/internal/websocket/event" + "go.uber.org/zap" +) + +const ( + expectContinueTimeoutDefault = 1 * time.Second + idleConnTimeoutDefault = 10 * time.Second + tlsHandshakeTimeoutDefault + fetchUrlTimeoutDefault = 30 * time.Second + pingIntervalDefault + connectionTimeoutDefault + keepAliveIntervalDefault + pingTimeoutDefault = 90 * time.Second + maxMessageBytesDefault = 256 * 1024 ) // DeviceID sets the device ID for the WS connection. @@ -54,7 +68,10 @@ func FetchURLTimeout(d time.Duration) Option { func(ws *Websocket) error { if d < 0 { return fmt.Errorf("%w: negative FetchURLTimeout", ErrMisconfiguredWS) + } else if d == 0 { + d = fetchUrlTimeoutDefault } + ws.urlFetchingTimeout = d return nil }) @@ -74,6 +91,12 @@ func CredentialsDecorator(f func(http.Header) error) Option { func PingInterval(d time.Duration) Option { return optionFunc( func(ws *Websocket) error { + if d < 0 { + return fmt.Errorf("%w: negative PingInterval", ErrMisconfiguredWS) + } else if d == 0 { + d = pingIntervalDefault + } + ws.pingInterval = d return nil }) @@ -84,6 +107,12 @@ func PingInterval(d time.Duration) Option { func PingTimeout(d time.Duration) Option { return optionFunc( func(ws *Websocket) error { + if d < 0 { + return fmt.Errorf("%w: negative PingTimeout", ErrMisconfiguredWS) + } else if d == 0 { + d = pingTimeoutDefault + } + ws.pingTimeout = d return nil }) @@ -94,6 +123,12 @@ func PingTimeout(d time.Duration) Option { func KeepAliveInterval(d time.Duration) Option { return optionFunc( func(ws *Websocket) error { + if d < 0 { + return fmt.Errorf("%w: negative KeepAliveInterval", ErrMisconfiguredWS) + } else if d == 0 { + d = keepAliveIntervalDefault + } + ws.keepAliveInterval = d return nil }) @@ -104,6 +139,12 @@ func KeepAliveInterval(d time.Duration) Option { func TLSHandshakeTimeout(d time.Duration) Option { return optionFunc( func(ws *Websocket) error { + if d < 0 { + return fmt.Errorf("%w: negative TLSHandshakeTimeout", ErrMisconfiguredWS) + } else if d == 0 { + d = tlsHandshakeTimeoutDefault + } + ws.tlsHandshakeTimeout = d return nil }) @@ -114,6 +155,12 @@ func TLSHandshakeTimeout(d time.Duration) Option { func IdleConnTimeout(d time.Duration) Option { return optionFunc( func(ws *Websocket) error { + if d < 0 { + return fmt.Errorf("%w: negative IdleConnTimeout", ErrMisconfiguredWS) + } else if d == 0 { + d = idleConnTimeoutDefault + } + ws.idleConnTimeout = d return nil }) @@ -124,6 +171,12 @@ func IdleConnTimeout(d time.Duration) Option { func ExpectContinueTimeout(d time.Duration) Option { return optionFunc( func(ws *Websocket) error { + if d < 0 { + return fmt.Errorf("%w: negative ExpectContinueTimeout", ErrMisconfiguredWS) + } else if d == 0 { + d = expectContinueTimeoutDefault + } + ws.expectContinueTimeout = d return nil }) @@ -158,7 +211,10 @@ func ConnectTimeout(d time.Duration) Option { func(ws *Websocket) error { if d < 0 { return fmt.Errorf("%w: negative ConnectTimeout", ErrMisconfiguredWS) + } else if d == 0 { + d = connectionTimeoutDefault } + ws.connectTimeout = d return nil }) @@ -187,6 +243,19 @@ func Once(once ...bool) Option { }) } +// Logger sets the zap logger. +func Logger(l *zap.Logger) Option { + return optionFunc( + func(ws *Websocket) error { + if l == nil { + l = sallust.Default() + } + + ws.l = l + return nil + }) +} + // NowFunc sets the now function for the WS connection. func NowFunc(f func() time.Time) Option { return optionFunc( @@ -194,6 +263,7 @@ func NowFunc(f func() time.Time) Option { if f == nil { return fmt.Errorf("%w: nil NowFunc", ErrMisconfiguredWS) } + ws.nowFunc = f return nil }) @@ -213,6 +283,10 @@ func RetryPolicy(pf retry.PolicyFactory) Option { func MaxMessageBytes(bytes int64) Option { return optionFunc( func(ws *Websocket) error { + if bytes == 0 { + bytes = maxMessageBytesDefault + } + ws.maxMessageBytes = bytes return nil }) diff --git a/internal/websocket/ws.go b/internal/websocket/ws.go index be788f2..74ea1b3 100644 --- a/internal/websocket/ws.go +++ b/internal/websocket/ws.go @@ -15,6 +15,7 @@ import ( "github.com/xmidt-org/retry" "github.com/xmidt-org/wrp-go/v3" "github.com/xmidt-org/xmidt-agent/internal/websocket/event" + "go.uber.org/zap" nhws "nhooyr.io/websocket" ) @@ -24,6 +25,11 @@ var ( ErrInvalidMsgType = errors.New("invalid message type") ) +// emptyBuffer is solely used as an address of a global empty buffer. +// This sentinel value will reset pointers of the writePump's encoder +// such that the gc can clean things up. +var emptyBuffer = []byte{} + type Websocket struct { // id is the device ID for the WS connection. id wrp.DeviceID @@ -94,6 +100,7 @@ type Websocket struct { m sync.Mutex wg sync.WaitGroup shutdown context.CancelFunc + l *zap.Logger conn *nhws.Conn } @@ -115,15 +122,15 @@ func New(opts ...Option) (*Websocket, error) { defaults := []Option{ NowFunc(time.Now), - FetchURLTimeout(30 * time.Second), - PingInterval(30 * time.Second), - PingTimeout(90 * time.Second), - ConnectTimeout(30 * time.Second), - KeepAliveInterval(30 * time.Second), - IdleConnTimeout(10 * time.Second), - TLSHandshakeTimeout(10 * time.Second), - ExpectContinueTimeout(1 * time.Second), - MaxMessageBytes(256 * 1024), + FetchURLTimeout(fetchUrlTimeoutDefault), + PingInterval(pingIntervalDefault), + PingTimeout(pingTimeoutDefault), + ConnectTimeout(connectionTimeoutDefault), + KeepAliveInterval(keepAliveIntervalDefault), + IdleConnTimeout(idleConnTimeoutDefault), + TLSHandshakeTimeout(tlsHandshakeTimeoutDefault), + ExpectContinueTimeout(expectContinueTimeoutDefault), + MaxMessageBytes(maxMessageBytesDefault), WithIPv4(), WithIPv6(), Once(false), @@ -225,7 +232,9 @@ func (ws *Websocket) run(ctx context.Context) { ws.wg.Add(1) defer ws.wg.Done() + pingTicker := time.NewTicker(ws.pingInterval) decoder := wrp.NewDecoder(nil, wrp.Msgpack) + encoder := wrp.NewEncoder(nil, wrp.Msgpack) mode := ws.nextMode(ipv4) policy := ws.retryPolicyFactory.NewPolicy(ctx) @@ -258,7 +267,7 @@ func (ws *Websocket) run(ctx context.Context) { // Read loop for { var msg wrp.Message - typ, reader, err := conn.Reader(ctx) + typ, reader, err := ws.conn.Reader(ctx) if err == nil { if typ != nhws.MessageBinary { err = ErrInvalidMsgType @@ -291,6 +300,20 @@ func (ws *Websocket) run(ctx context.Context) { ws.msgListeners.Visit(func(l event.MsgListener) { l.OnMessage(msg) }) + var frameContents []byte + // nolint: typecheck + + // if the request was in a format other than Msgpack, or if the caller did not pass + // Contents, then do the encoding here. + encoder.ResetBytes(&frameContents) + err = encoder.Encode(msg) + encoder.ResetBytes(&emptyBuffer) + if err != nil { + ws.l.Error("xmidt-agent failed to response to wrp message", zap.Error(err)) + continue + } + + ws.conn.Write(ctx, nhws.MessageBinary, frameContents) } }