From 261872ed92047467656826e77745379c8c23548b Mon Sep 17 00:00:00 2001 From: Muzzammil Shahid Date: Wed, 8 Jan 2025 13:16:20 +0500 Subject: [PATCH] Refactor: make config file loading and validation code reusable --- cmd/xconn/main.go | 59 ++--------------------- {cmd/xconn => util}/authenticator.go | 2 +- {cmd/xconn => util}/config.go | 2 +- {cmd/xconn => util}/types.go | 2 +- util/util.go | 71 ++++++++++++++++++++++++++++ 5 files changed, 77 insertions(+), 59 deletions(-) rename {cmd/xconn => util}/authenticator.go (99%) rename {cmd/xconn => util}/config.go (99%) rename {cmd/xconn => util}/types.go (99%) create mode 100644 util/util.go diff --git a/cmd/xconn/main.go b/cmd/xconn/main.go index 5ba5eba..6cd110f 100644 --- a/cmd/xconn/main.go +++ b/cmd/xconn/main.go @@ -1,22 +1,16 @@ package main import ( - "bytes" _ "embed" // nolint:gci "fmt" - "io" "log" "os" "os/signal" "path/filepath" - "time" "github.com/alecthomas/kingpin/v2" - "golang.org/x/exp/slices" - "gopkg.in/yaml.v3" - "github.com/xconnio/xconn-go" - "github.com/xconnio/xconn-go/internal" + "github.com/xconnio/xconn-go/util" ) var ( @@ -79,56 +73,9 @@ func Run(args []string) error { } case c.start.FullCommand(): - data, err := os.ReadFile(configFile) + closers, err := util.StartServerFromConfigFile(configFile) if err != nil { - return fmt.Errorf("unable to read config file: %w", err) - } - - var decoder = yaml.NewDecoder(bytes.NewBuffer(data)) - decoder.KnownFields(true) - - var config Config - if err := decoder.Decode(&config); err != nil { - return fmt.Errorf("unable to decode config file: %w", err) - } - - if err := config.Validate(); err != nil { - return fmt.Errorf("invalid config: %w", err) - } - - router := xconn.NewRouter() - defer router.Close() - - for _, realm := range config.Realms { - router.AddRealm(realm.Name) - } - - authenticator := NewAuthenticator(config.Authenticators) - - closers := make([]io.Closer, 0) - for _, transport := range config.Transports { - var throttle *internal.Throttle - if transport.RateLimit.Rate > 0 && transport.RateLimit.Interval > 0 { - strategy := internal.Burst - if transport.RateLimit.Strategy == LeakyBucketStrategy { - strategy = internal.LeakyBucket - } - throttle = internal.NewThrottle(transport.RateLimit.Rate, - time.Duration(transport.RateLimit.Interval)*time.Second, strategy) - } - server := xconn.NewServer(router, authenticator, &xconn.ServerConfig{Throttle: throttle}) - if slices.Contains(transport.Serializers, "protobuf") { - if err := server.RegisterSpec(xconn.ProtobufSerializerSpec); err != nil { - return err - } - } - - closer, err := server.Start(transport.Host, transport.Port) - if err != nil { - return err - } - - closers = append(closers, closer) + return err } // Close server if SIGINT (CTRL-c) received. diff --git a/cmd/xconn/authenticator.go b/util/authenticator.go similarity index 99% rename from cmd/xconn/authenticator.go rename to util/authenticator.go index d406473..1b44205 100644 --- a/cmd/xconn/authenticator.go +++ b/util/authenticator.go @@ -1,4 +1,4 @@ -package main +package util import ( "fmt" diff --git a/cmd/xconn/config.go b/util/config.go similarity index 99% rename from cmd/xconn/config.go rename to util/config.go index ca22c2e..7919589 100644 --- a/cmd/xconn/config.go +++ b/util/config.go @@ -1,4 +1,4 @@ -package main +package util import ( "encoding/hex" diff --git a/cmd/xconn/types.go b/util/types.go similarity index 99% rename from cmd/xconn/types.go rename to util/types.go index 3506d10..b8648b4 100644 --- a/cmd/xconn/types.go +++ b/util/types.go @@ -1,4 +1,4 @@ -package main +package util type Realm struct { Name string `yaml:"name"` diff --git a/util/util.go b/util/util.go new file mode 100644 index 0000000..0bd6260 --- /dev/null +++ b/util/util.go @@ -0,0 +1,71 @@ +package util + +import ( + "bytes" + "fmt" + "io" + "os" + "time" + + "golang.org/x/exp/slices" + "gopkg.in/yaml.v3" + + "github.com/xconnio/xconn-go" + "github.com/xconnio/xconn-go/internal" +) + +func StartServerFromConfigFile(configFile string) ([]io.Closer, error) { + data, err := os.ReadFile(configFile) + if err != nil { + return nil, fmt.Errorf("unable to read config file: %w", err) + } + + var decoder = yaml.NewDecoder(bytes.NewBuffer(data)) + decoder.KnownFields(true) + + var config Config + if err := decoder.Decode(&config); err != nil { + return nil, fmt.Errorf("unable to decode config file: %w", err) + } + + if err := config.Validate(); err != nil { + return nil, fmt.Errorf("invalid config: %w", err) + } + + router := xconn.NewRouter() + defer router.Close() + + for _, realm := range config.Realms { + router.AddRealm(realm.Name) + } + + authenticator := NewAuthenticator(config.Authenticators) + + closers := make([]io.Closer, 0) + for _, transport := range config.Transports { + var throttle *internal.Throttle + if transport.RateLimit.Rate > 0 && transport.RateLimit.Interval > 0 { + strategy := internal.Burst + if transport.RateLimit.Strategy == LeakyBucketStrategy { + strategy = internal.LeakyBucket + } + throttle = internal.NewThrottle(transport.RateLimit.Rate, + time.Duration(transport.RateLimit.Interval)*time.Second, strategy) + } + server := xconn.NewServer(router, authenticator, &xconn.ServerConfig{Throttle: throttle}) + if slices.Contains(transport.Serializers, "protobuf") { + if err := server.RegisterSpec(xconn.ProtobufSerializerSpec); err != nil { + return nil, err + } + } + + closer, err := server.Start(transport.Host, transport.Port) + if err != nil { + return nil, err + } + + closers = append(closers, closer) + } + + return closers, nil +}