diff --git a/go.mod b/go.mod index 41c3b7e68..91bd78b80 100644 --- a/go.mod +++ b/go.mod @@ -4,23 +4,18 @@ go 1.22 require ( github.com/creasty/defaults v1.7.0 - github.com/go-sql-driver/mysql v1.8.1 github.com/goccy/go-yaml v1.11.3 github.com/google/go-cmp v0.6.0 github.com/google/uuid v1.6.0 github.com/icinga/icinga-go-library v0.0.0-20240522094431-f6f2eb363d60 github.com/jessevdk/go-flags v1.5.0 github.com/jmoiron/sqlx v1.4.0 - github.com/lib/pq v1.10.9 github.com/mattn/go-sqlite3 v1.14.22 github.com/okzk/sdnotify v0.0.0-20180710141335-d9becc38acbd github.com/pkg/errors v0.9.1 - github.com/redis/go-redis/v9 v9.5.1 - github.com/ssgreg/journald v1.0.0 github.com/stretchr/testify v1.9.0 github.com/vbauerster/mpb/v6 v6.0.4 go.uber.org/zap v1.27.0 - golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 golang.org/x/sync v0.7.0 ) @@ -32,12 +27,17 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/fatih/color v1.16.0 // indirect + github.com/go-sql-driver/mysql v1.8.1 // indirect + github.com/lib/pq v1.10.9 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-runewidth v0.0.12 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/redis/go-redis/v9 v9.5.1 // indirect github.com/rivo/uniseg v0.2.0 // indirect + github.com/ssgreg/journald v1.0.0 // indirect go.uber.org/multierr v1.11.0 // indirect + golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect golang.org/x/sys v0.14.0 // indirect golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/pkg/backoff/backoff.go b/pkg/backoff/backoff.go deleted file mode 100644 index e79a1ee7d..000000000 --- a/pkg/backoff/backoff.go +++ /dev/null @@ -1,43 +0,0 @@ -package backoff - -import ( - "math/rand" - "time" -) - -// Backoff returns the backoff duration for a specific retry attempt. -type Backoff func(uint64) time.Duration - -// NewExponentialWithJitter returns a backoff implementation that -// exponentially increases the backoff duration for each retry from min, -// never exceeding max. Some randomization is added to the backoff duration. -// It panics if min >= max. -func NewExponentialWithJitter(min, max time.Duration) Backoff { - if min <= 0 { - min = 100 * time.Millisecond - } - if max <= 0 { - max = 10 * time.Second - } - if min >= max { - panic("max must be larger than min") - } - - return func(attempt uint64) time.Duration { - e := min << attempt - if e <= 0 || e > max { - e = max - } - - return time.Duration(jitter(int64(e))) - } -} - -// jitter returns a random integer distributed in the range [n/2..n). -func jitter(n int64) int64 { - if n == 0 { - return 0 - } - - return n/2 + rand.Int63n(n/2) -} diff --git a/pkg/com/atomic.go b/pkg/com/atomic.go deleted file mode 100644 index 316413dfd..000000000 --- a/pkg/com/atomic.go +++ /dev/null @@ -1,38 +0,0 @@ -package com - -import "sync/atomic" - -// Atomic is a type-safe wrapper around atomic.Value. -type Atomic[T any] struct { - v atomic.Value -} - -func (a *Atomic[T]) Load() (_ T, ok bool) { - if v, ok := a.v.Load().(box[T]); ok { - return v.v, true - } - - return -} - -func (a *Atomic[T]) Store(v T) { - a.v.Store(box[T]{v}) -} - -func (a *Atomic[T]) Swap(new T) (old T, ok bool) { - if old, ok := a.v.Swap(box[T]{new}).(box[T]); ok { - return old.v, true - } - - return -} - -func (a *Atomic[T]) CompareAndSwap(old, new T) (swapped bool) { - return a.v.CompareAndSwap(box[T]{old}, box[T]{new}) -} - -// box allows, for the case T is an interface, nil values and values of different specific types implementing T -// to be stored in Atomic[T]#v (bypassing atomic.Value#Store()'s policy) by wrapping it (into a non-interface). -type box[T any] struct { - v T -} diff --git a/pkg/com/bulker.go b/pkg/com/bulker.go deleted file mode 100644 index e4fe7aa9e..000000000 --- a/pkg/com/bulker.go +++ /dev/null @@ -1,166 +0,0 @@ -package com - -import ( - "context" - "golang.org/x/sync/errgroup" - "sync" - "time" -) - -// BulkChunkSplitPolicy is a state machine which tracks the items of a chunk a bulker assembles. -// A call takes an item for the current chunk into account. -// Output true indicates that the state machine was reset first and the bulker -// shall finish the current chunk now (not e.g. once $size is reached) without the given item. -type BulkChunkSplitPolicy[T any] func(T) bool - -type BulkChunkSplitPolicyFactory[T any] func() BulkChunkSplitPolicy[T] - -// NeverSplit returns a pseudo state machine which never demands splitting. -func NeverSplit[T any]() BulkChunkSplitPolicy[T] { - return neverSplit[T] -} - -func neverSplit[T any](T) bool { - return false -} - -// Bulker reads all values from a channel and streams them in chunks into a Bulk channel. -type Bulker[T any] struct { - ch chan []T - ctx context.Context - mu sync.Mutex -} - -// NewBulker returns a new Bulker and starts streaming. -func NewBulker[T any]( - ctx context.Context, ch <-chan T, count int, splitPolicyFactory BulkChunkSplitPolicyFactory[T], -) *Bulker[T] { - b := &Bulker[T]{ - ch: make(chan []T), - ctx: ctx, - mu: sync.Mutex{}, - } - - go b.run(ch, count, splitPolicyFactory) - - return b -} - -// Bulk returns the channel on which the bulks are delivered. -func (b *Bulker[T]) Bulk() <-chan []T { - return b.ch -} - -func (b *Bulker[T]) run(ch <-chan T, count int, splitPolicyFactory BulkChunkSplitPolicyFactory[T]) { - defer close(b.ch) - - bufCh := make(chan T, count) - splitPolicy := splitPolicyFactory() - g, ctx := errgroup.WithContext(b.ctx) - - g.Go(func() error { - defer close(bufCh) - - for { - select { - case v, ok := <-ch: - if !ok { - return nil - } - - bufCh <- v - case <-ctx.Done(): - return ctx.Err() - } - } - }) - - g.Go(func() error { - for done := false; !done; { - buf := make([]T, 0, count) - timeout := time.After(256 * time.Millisecond) - - for drain := true; drain && len(buf) < count; { - select { - case v, ok := <-bufCh: - if !ok { - drain = false - done = true - - break - } - - if splitPolicy(v) { - if len(buf) > 0 { - b.ch <- buf - buf = make([]T, 0, count) - } - - timeout = time.After(256 * time.Millisecond) - } - - buf = append(buf, v) - case <-timeout: - drain = false - case <-ctx.Done(): - return ctx.Err() - } - } - - if len(buf) > 0 { - b.ch <- buf - } - - splitPolicy = splitPolicyFactory() - } - - return nil - }) - - // We don't expect an error here. - // We only use errgroup for the encapsulated use of sync.WaitGroup. - _ = g.Wait() -} - -// Bulk reads all values from a channel and streams them in chunks into a returned channel. -func Bulk[T any]( - ctx context.Context, ch <-chan T, count int, splitPolicyFactory BulkChunkSplitPolicyFactory[T], -) <-chan []T { - if count <= 1 { - return oneBulk(ctx, ch) - } - - return NewBulker(ctx, ch, count, splitPolicyFactory).Bulk() -} - -// oneBulk operates just as NewBulker(ctx, ch, 1, splitPolicy).Bulk(), -// but without the overhead of the actual bulk creation with a buffer channel, timeout and BulkChunkSplitPolicy. -func oneBulk[T any](ctx context.Context, ch <-chan T) <-chan []T { - out := make(chan []T) - go func() { - defer close(out) - - for { - select { - case item, ok := <-ch: - if !ok { - return - } - - select { - case out <- []T{item}: - case <-ctx.Done(): - return - } - case <-ctx.Done(): - return - } - } - }() - - return out -} - -var ( - _ BulkChunkSplitPolicyFactory[struct{}] = NeverSplit[struct{}] -) diff --git a/pkg/com/com.go b/pkg/com/com.go deleted file mode 100644 index 229396409..000000000 --- a/pkg/com/com.go +++ /dev/null @@ -1,101 +0,0 @@ -package com - -import ( - "context" - "github.com/pkg/errors" - "golang.org/x/sync/errgroup" -) - -// Waiter implements the Wait method, -// which blocks until execution is complete. -type Waiter interface { - Wait() error // Wait waits for execution to complete. -} - -// The WaiterFunc type is an adapter to allow the use of ordinary functions as Waiter. -// If f is a function with the appropriate signature, WaiterFunc(f) is a Waiter that calls f. -type WaiterFunc func() error - -// Wait implements the Waiter interface. -func (f WaiterFunc) Wait() error { - return f() -} - -// WaitAsync calls Wait() on the passed Waiter in a new goroutine and -// sends the first non-nil error (if any) to the returned channel. -// The returned channel is always closed when the Waiter is done. -func WaitAsync(w Waiter) <-chan error { - errs := make(chan error, 1) - - go func() { - defer close(errs) - - if e := w.Wait(); e != nil { - errs <- e - } - }() - - return errs -} - -// ErrgroupReceive adds a goroutine to the specified group that -// returns the first non-nil error (if any) from the specified channel. -// If the channel is closed, it will return nil. -func ErrgroupReceive(g *errgroup.Group, err <-chan error) { - g.Go(func() error { - if e := <-err; e != nil { - return e - } - - return nil - }) -} - -// CopyFirst asynchronously forwards all items from input to forward and synchronously returns the first item. -func CopyFirst[T any]( - ctx context.Context, input <-chan T, -) (first T, forward <-chan T, err error) { - var ok bool - select { - case <-ctx.Done(): - var zero T - - return zero, nil, ctx.Err() - case first, ok = <-input: - } - - if !ok { - err = errors.New("can't copy from closed channel") - - return - } - - // Buffer of one because we receive an entity and send it back immediately. - fwd := make(chan T, 1) - fwd <- first - - forward = fwd - - go func() { - defer close(fwd) - - for { - select { - case <-ctx.Done(): - return - case e, ok := <-input: - if !ok { - return - } - - select { - case <-ctx.Done(): - return - case fwd <- e: - } - } - } - }() - - return -} diff --git a/pkg/com/cond.go b/pkg/com/cond.go deleted file mode 100644 index 72ba347c5..000000000 --- a/pkg/com/cond.go +++ /dev/null @@ -1,90 +0,0 @@ -package com - -import ( - "context" - "github.com/pkg/errors" -) - -// Cond implements a channel-based synchronization for goroutines that wait for signals or send them. -// Internally based on a controller loop that handles the synchronization of new listeners and signal propagation, -// which is only started when NewCond is called. Thus the zero value cannot be used. -type Cond struct { - broadcast chan struct{} - done chan struct{} - cancel context.CancelFunc - listeners chan chan struct{} -} - -// NewCond returns a new Cond and starts the controller loop. -func NewCond(ctx context.Context) *Cond { - ctx, cancel := context.WithCancel(ctx) - - c := &Cond{ - broadcast: make(chan struct{}), - cancel: cancel, - done: make(chan struct{}), - listeners: make(chan chan struct{}), - } - - go c.controller(ctx) - - return c -} - -// Broadcast sends a signal to all current listeners by closing the previously returned channel from Wait. -// Panics if the controller loop has already ended. -func (c *Cond) Broadcast() { - select { - case c.broadcast <- struct{}{}: - case <-c.done: - panic(errors.New("condition closed")) - } -} - -// Close stops the controller loop, waits for it to finish, and returns an error if any. -// Implements the io.Closer interface. -func (c *Cond) Close() error { - c.cancel() - <-c.done - - return nil -} - -// Done returns a channel that will be closed when the controller loop has ended. -func (c *Cond) Done() <-chan struct{} { - return c.done -} - -// Wait returns a channel that is closed with the next signal. -// Panics if the controller loop has already ended. -func (c *Cond) Wait() <-chan struct{} { - select { - case l := <-c.listeners: - return l - case <-c.done: - panic(errors.New("condition closed")) - } -} - -// controller loop. -func (c *Cond) controller(ctx context.Context) { - defer close(c.done) - - // Note that the notify channel does not close when the controller loop ends - // in order not to notify pending listeners. - notify := make(chan struct{}) - - for { - select { - case <-c.broadcast: - // Close channel to notify all current listeners. - close(notify) - // Create a new channel for the next listeners. - notify = make(chan struct{}) - case c.listeners <- notify: - // A new listener received the channel. - case <-ctx.Done(): - return - } - } -} diff --git a/pkg/com/counter.go b/pkg/com/counter.go deleted file mode 100644 index 52f9f7ff2..000000000 --- a/pkg/com/counter.go +++ /dev/null @@ -1,48 +0,0 @@ -package com - -import ( - "sync" - "sync/atomic" -) - -// Counter implements an atomic counter. -type Counter struct { - value uint64 - mu sync.Mutex // Protects total. - total uint64 -} - -// Add adds the given delta to the counter. -func (c *Counter) Add(delta uint64) { - atomic.AddUint64(&c.value, delta) -} - -// Inc increments the counter by one. -func (c *Counter) Inc() { - c.Add(1) -} - -// Reset resets the counter to 0 and returns its previous value. -// Does not reset the total value returned from Total. -func (c *Counter) Reset() uint64 { - c.mu.Lock() - defer c.mu.Unlock() - - v := atomic.SwapUint64(&c.value, 0) - c.total += v - - return v -} - -// Total returns the total counter value. -func (c *Counter) Total() uint64 { - c.mu.Lock() - defer c.mu.Unlock() - - return c.total + c.Val() -} - -// Val returns the current counter value. -func (c *Counter) Val() uint64 { - return atomic.LoadUint64(&c.value) -} diff --git a/pkg/config/config.go b/pkg/config/config.go deleted file mode 100644 index 55df4ad9c..000000000 --- a/pkg/config/config.go +++ /dev/null @@ -1,78 +0,0 @@ -package config - -import ( - stderrors "errors" - "fmt" - "github.com/creasty/defaults" - "github.com/goccy/go-yaml" - "github.com/jessevdk/go-flags" - "github.com/pkg/errors" - "os" - "reflect" -) - -// ErrInvalidArgument is the error returned by [ParseFlags] or [FromYAMLFile] if -// its parsing result cannot be stored in the value pointed to by the designated passed argument which -// must be a non-nil pointer. -var ErrInvalidArgument = stderrors.New("invalid argument") - -// FromYAMLFile parses the given YAML file and stores the result -// in the value pointed to by v. If v is nil or not a pointer, -// FromYAMLFile returns an [ErrInvalidArgument] error. -func FromYAMLFile(name string, v Validator) error { - rv := reflect.ValueOf(v) - if rv.Kind() != reflect.Pointer || rv.IsNil() { - return errors.Wrapf(ErrInvalidArgument, "non-nil pointer expected, got %T", v) - } - - f, err := os.Open(name) - if err != nil { - return errors.Wrap(err, "can't open YAML file "+name) - } - defer func(f *os.File) { - _ = f.Close() - }(f) - - if err := defaults.Set(v); err != nil { - return errors.Wrap(err, "can't set config defaults") - } - - d := yaml.NewDecoder(f, yaml.DisallowUnknownField()) - if err := d.Decode(v); err != nil { - return errors.Wrap(err, "can't parse YAML file "+name) - } - - if err := v.Validate(); err != nil { - return errors.Wrap(err, "invalid configuration") - } - - return nil -} - -// ParseFlags parses CLI flags and stores the result -// in the value pointed to by v. If v is nil or not a pointer, -// ParseFlags returns an [ErrInvalidArgument] error. -// ParseFlags adds a default Help Options group, -// which contains the options -h and --help. -// If either option is specified on the command line, -// ParseFlags prints the help message to [os.Stdout] and exits. -func ParseFlags(v any) error { - rv := reflect.ValueOf(v) - if rv.Kind() != reflect.Pointer || rv.IsNil() { - return errors.Wrapf(ErrInvalidArgument, "non-nil pointer expected, got %T", v) - } - - parser := flags.NewParser(v, flags.Default^flags.PrintErrors) - - if _, err := parser.Parse(); err != nil { - var flagErr *flags.Error - if errors.As(err, &flagErr) && flagErr.Type == flags.ErrHelp { - fmt.Fprintln(os.Stdout, flagErr) - os.Exit(0) - } - - return errors.Wrap(err, "can't parse CLI flags") - } - - return nil -} diff --git a/pkg/config/contracts.go b/pkg/config/contracts.go deleted file mode 100644 index 760229914..000000000 --- a/pkg/config/contracts.go +++ /dev/null @@ -1,5 +0,0 @@ -package config - -type Validator interface { - Validate() error -} diff --git a/pkg/config/tls.go b/pkg/config/tls.go deleted file mode 100644 index 3fc6f649c..000000000 --- a/pkg/config/tls.go +++ /dev/null @@ -1,58 +0,0 @@ -package config - -import ( - "crypto/tls" - "crypto/x509" - "github.com/pkg/errors" - "os" -) - -// TLS provides TLS configuration options. -type TLS struct { - Enable bool `yaml:"tls"` - Cert string `yaml:"cert"` - Key string `yaml:"key"` - Ca string `yaml:"ca"` - Insecure bool `yaml:"insecure"` -} - -// MakeConfig assembles a tls.Config from t and serverName. -func (t *TLS) MakeConfig(serverName string) (*tls.Config, error) { - if !t.Enable { - return nil, nil - } - - tlsConfig := &tls.Config{} - if t.Cert == "" { - if t.Key != "" { - return nil, errors.New("private key given, but client certificate missing") - } - } else if t.Key == "" { - return nil, errors.New("client certificate given, but private key missing") - } else { - crt, err := tls.LoadX509KeyPair(t.Cert, t.Key) - if err != nil { - return nil, errors.Wrap(err, "can't load X.509 key pair") - } - - tlsConfig.Certificates = []tls.Certificate{crt} - } - - if t.Insecure { - tlsConfig.InsecureSkipVerify = true - } else if t.Ca != "" { - raw, err := os.ReadFile(t.Ca) - if err != nil { - return nil, errors.Wrap(err, "can't read CA file") - } - - tlsConfig.RootCAs = x509.NewCertPool() - if !tlsConfig.RootCAs.AppendCertsFromPEM(raw) { - return nil, errors.New("can't parse CA file") - } - } - - tlsConfig.ServerName = serverName - - return tlsConfig, nil -} diff --git a/pkg/database/column_map.go b/pkg/database/column_map.go deleted file mode 100644 index 5642c841e..000000000 --- a/pkg/database/column_map.go +++ /dev/null @@ -1,75 +0,0 @@ -package database - -import ( - "database/sql/driver" - "github.com/jmoiron/sqlx/reflectx" - "reflect" - "sync" -) - -// ColumnMap provides a cached mapping of structs exported fields to their database column names. -type ColumnMap interface { - // Columns returns database column names for a struct's exported fields in a cached manner. - // Thus, the returned slice MUST NOT be modified directly. - // By default, all exported struct fields are mapped to database column names using snake case notation. - // The - (hyphen) directive for the db tag can be used to exclude certain fields. - Columns(any) []string -} - -// NewColumnMap returns a new ColumnMap. -func NewColumnMap(mapper *reflectx.Mapper) ColumnMap { - return &columnMap{ - cache: make(map[reflect.Type][]string), - mapper: mapper, - } -} - -type columnMap struct { - mutex sync.Mutex - cache map[reflect.Type][]string - mapper *reflectx.Mapper -} - -func (m *columnMap) Columns(subject any) []string { - m.mutex.Lock() - defer m.mutex.Unlock() - - t, ok := subject.(reflect.Type) - if !ok { - t = reflect.TypeOf(subject) - } - - columns, ok := m.cache[t] - if !ok { - columns = m.getColumns(t) - m.cache[t] = columns - } - - return columns -} - -func (m *columnMap) getColumns(t reflect.Type) []string { - fields := m.mapper.TypeMap(t).Names - columns := make([]string, 0, len(fields)) - -FieldLoop: - for _, f := range fields { - // If one of the parent fields implements the driver.Valuer interface, the field can be ignored. - for parent := f.Parent; parent != nil && parent.Zero.IsValid(); parent = parent.Parent { - // Check for pointer types. - if _, ok := reflect.New(parent.Field.Type).Interface().(driver.Valuer); ok { - continue FieldLoop - } - // Check for non-pointer types. - if _, ok := reflect.Zero(parent.Field.Type).Interface().(driver.Valuer); ok { - continue FieldLoop - } - } - - columns = append(columns, f.Path) - } - - // Shrink/reduce slice length and capacity: - // For a three-index slice (slice[a:b:c]), the length of the returned slice is b-a and the capacity is c-a. - return columns[0:len(columns):len(columns)] -} diff --git a/pkg/database/config.go b/pkg/database/config.go deleted file mode 100644 index bfcf299e4..000000000 --- a/pkg/database/config.go +++ /dev/null @@ -1,45 +0,0 @@ -package database - -import ( - "github.com/icinga/icinga-go-library/config" - "github.com/pkg/errors" -) - -// Config defines database client configuration. -type Config struct { - Type string `yaml:"type" default:"mysql"` - Host string `yaml:"host"` - Port int `yaml:"port"` - Database string `yaml:"database"` - User string `yaml:"user"` - Password string `yaml:"password"` - TlsOptions config.TLS `yaml:",inline"` - Options Options `yaml:"options"` -} - -// Validate checks constraints in the supplied database configuration and returns an error if they are violated. -func (c *Config) Validate() error { - switch c.Type { - case "mysql", "pgsql": - default: - return unknownDbType(c.Type) - } - - if c.Host == "" { - return errors.New("database host missing") - } - - if c.User == "" { - return errors.New("database user missing") - } - - if c.Database == "" { - return errors.New("database name missing") - } - - return c.Options.Validate() -} - -func unknownDbType(t string) error { - return errors.Errorf(`unknown database type %q, must be one of: "mysql", "pgsql"`, t) -} diff --git a/pkg/database/contracts.go b/pkg/database/contracts.go deleted file mode 100644 index bf55d3207..000000000 --- a/pkg/database/contracts.go +++ /dev/null @@ -1,56 +0,0 @@ -package database - -// Entity is implemented by each type that works with the database package. -type Entity interface { - Fingerprinter - IDer -} - -// Fingerprinter is implemented by every entity that uniquely identifies itself. -type Fingerprinter interface { - // Fingerprint returns the value that uniquely identifies the entity. - Fingerprint() Fingerprinter -} - -// ID is a unique identifier of an entity. -type ID interface { - // String returns the string representation form of the ID. - // The String method is used to use the ID in functions - // where it needs to be compared or hashed. - String() string -} - -// IDer is implemented by every entity that uniquely identifies itself. -type IDer interface { - ID() ID // ID returns the ID. - SetID(ID) // SetID sets the ID. -} - -// EntityFactoryFunc knows how to create an Entity. -type EntityFactoryFunc func() Entity - -// Upserter implements the Upsert method, -// which returns a part of the object for ON DUPLICATE KEY UPDATE. -type Upserter interface { - Upsert() any // Upsert partitions the object. -} - -// TableNamer implements the TableName method, -// which returns the table of the object. -type TableNamer interface { - TableName() string // TableName tells the table. -} - -// Scoper implements the Scope method, -// which returns a struct specifying the WHERE conditions that -// entities must satisfy in order to be SELECTed. -type Scoper interface { - Scope() any -} - -// PgsqlOnConflictConstrainter implements the PgsqlOnConflictConstraint method, -// which returns the primary or unique key constraint name of the PostgreSQL table. -type PgsqlOnConflictConstrainter interface { - // PgsqlOnConflictConstraint returns the primary or unique key constraint name of the PostgreSQL table. - PgsqlOnConflictConstraint() string -} diff --git a/pkg/database/db.go b/pkg/database/db.go deleted file mode 100644 index 3eb2df76f..000000000 --- a/pkg/database/db.go +++ /dev/null @@ -1,800 +0,0 @@ -package database - -import ( - "context" - "database/sql" - "database/sql/driver" - "fmt" - "github.com/go-sql-driver/mysql" - "github.com/icinga/icinga-go-library/backoff" - "github.com/icinga/icinga-go-library/com" - "github.com/icinga/icinga-go-library/logging" - "github.com/icinga/icinga-go-library/periodic" - "github.com/icinga/icinga-go-library/retry" - "github.com/icinga/icinga-go-library/strcase" - "github.com/icinga/icinga-go-library/utils" - "github.com/jmoiron/sqlx" - "github.com/jmoiron/sqlx/reflectx" - "github.com/lib/pq" - "github.com/pkg/errors" - "go.uber.org/zap" - "golang.org/x/sync/errgroup" - "golang.org/x/sync/semaphore" - "net" - "net/url" - "strconv" - "strings" - "sync" - "time" -) - -// DB is a wrapper around sqlx.DB with bulk execution, -// statement building, streaming and logging capabilities. -type DB struct { - *sqlx.DB - - Options *Options - - addr string - columnMap ColumnMap - logger *logging.Logger - tableSemaphores map[string]*semaphore.Weighted - tableSemaphoresMu sync.Mutex -} - -// Options define user configurable database options. -type Options struct { - // Maximum number of open connections to the database. - MaxConnections int `yaml:"max_connections" default:"16"` - - // Maximum number of connections per table, - // regardless of what the connection is actually doing, - // e.g. INSERT, UPDATE, DELETE. - MaxConnectionsPerTable int `yaml:"max_connections_per_table" default:"8"` - - // MaxPlaceholdersPerStatement defines the maximum number of placeholders in an - // INSERT, UPDATE or DELETE statement. Theoretically, MySQL can handle up to 2^16-1 placeholders, - // but this increases the execution time of queries and thus reduces the number of queries - // that can be executed in parallel in a given time. - // The default is 2^13, which in our tests showed the best performance in terms of execution time and parallelism. - MaxPlaceholdersPerStatement int `yaml:"max_placeholders_per_statement" default:"8192"` - - // MaxRowsPerTransaction defines the maximum number of rows per transaction. - // The default is 2^13, which in our tests showed the best performance in terms of execution time and parallelism. - MaxRowsPerTransaction int `yaml:"max_rows_per_transaction" default:"8192"` - - // WsrepSyncWait enforces Galera cluster nodes to perform strict cluster-wide causality checks - // before executing specific SQL queries determined by the number you provided. - // Please refer to the below link for a detailed description. - // https://icinga.com/docs/icinga-db/latest/doc/03-Configuration/#galera-cluster - WsrepSyncWait int `yaml:"wsrep_sync_wait" default:"7"` -} - -// Validate checks constraints in the supplied database options and returns an error if they are violated. -func (o *Options) Validate() error { - if o.MaxConnections == 0 { - return errors.New("max_connections cannot be 0. Configure a value greater than zero, or use -1 for no connection limit") - } - if o.MaxConnectionsPerTable < 1 { - return errors.New("max_connections_per_table must be at least 1") - } - if o.MaxPlaceholdersPerStatement < 1 { - return errors.New("max_placeholders_per_statement must be at least 1") - } - if o.MaxRowsPerTransaction < 1 { - return errors.New("max_rows_per_transaction must be at least 1") - } - if o.WsrepSyncWait < 0 || o.WsrepSyncWait > 15 { - return errors.New("wsrep_sync_wait can only be set to a number between 0 and 15") - } - - return nil -} - -// NewDbFromConfig returns a new DB from Config. -func NewDbFromConfig(c *Config, logger *logging.Logger, connectorCallbacks RetryConnectorCallbacks) (*DB, error) { - var addr string - var db *sqlx.DB - - switch c.Type { - case "mysql": - config := mysql.NewConfig() - - config.User = c.User - config.Passwd = c.Password - config.Logger = MysqlFuncLogger(logger.Debug) - - if utils.IsUnixAddr(c.Host) { - config.Net = "unix" - config.Addr = c.Host - } else { - config.Net = "tcp" - port := c.Port - if port == 0 { - port = 3306 - } - config.Addr = net.JoinHostPort(c.Host, fmt.Sprint(port)) - } - - config.DBName = c.Database - config.Timeout = time.Minute - config.Params = map[string]string{"sql_mode": "'TRADITIONAL,ANSI_QUOTES'"} - - tlsConfig, err := c.TlsOptions.MakeConfig(c.Host) - if err != nil { - return nil, err - } - - config.TLS = tlsConfig - - connector, err := mysql.NewConnector(config) - if err != nil { - return nil, errors.Wrap(err, "can't open mysql database") - } - - onInitConn := connectorCallbacks.OnInitConn - connectorCallbacks.OnInitConn = func(ctx context.Context, conn driver.Conn) error { - if onInitConn != nil { - if err := onInitConn(ctx, conn); err != nil { - return err - } - } - - return setGaleraOpts(ctx, conn, int64(c.Options.WsrepSyncWait)) - } - - addr = config.Addr - db = sqlx.NewDb(sql.OpenDB(NewConnector(connector, logger, connectorCallbacks)), MySQL) - case "pgsql": - uri := &url.URL{ - Scheme: "postgres", - User: url.UserPassword(c.User, c.Password), - Path: "/" + url.PathEscape(c.Database), - } - - query := url.Values{ - "connect_timeout": {"60"}, - "binary_parameters": {"yes"}, - - // Host and port can alternatively be specified in the query string. lib/pq can't parse the connection URI - // if a Unix domain socket path is specified in the host part of the URI, therefore always use the query - // string. See also https://github.com/lib/pq/issues/796 - "host": {c.Host}, - } - - port := c.Port - if port == 0 { - port = 5432 - } - query["port"] = []string{strconv.FormatInt(int64(port), 10)} - - if _, err := c.TlsOptions.MakeConfig(c.Host); err != nil { - return nil, err - } - - if c.TlsOptions.Enable { - if c.TlsOptions.Insecure { - query["sslmode"] = []string{"require"} - } else { - query["sslmode"] = []string{"verify-full"} - } - - if c.TlsOptions.Cert != "" { - query["sslcert"] = []string{c.TlsOptions.Cert} - } - - if c.TlsOptions.Key != "" { - query["sslkey"] = []string{c.TlsOptions.Key} - } - - if c.TlsOptions.Ca != "" { - query["sslrootcert"] = []string{c.TlsOptions.Ca} - } - } else { - query["sslmode"] = []string{"disable"} - } - - uri.RawQuery = query.Encode() - - connector, err := pq.NewConnector(uri.String()) - if err != nil { - return nil, errors.Wrap(err, "can't open pgsql database") - } - - addr = utils.JoinHostPort(c.Host, port) - db = sqlx.NewDb(sql.OpenDB(NewConnector(connector, logger, connectorCallbacks)), PostgreSQL) - default: - return nil, unknownDbType(c.Type) - } - - db.SetMaxIdleConns(c.Options.MaxConnections / 3) - db.SetMaxOpenConns(c.Options.MaxConnections) - - db.Mapper = reflectx.NewMapperFunc("db", strcase.Snake) - - return &DB{ - DB: db, - Options: &c.Options, - columnMap: NewColumnMap(db.Mapper), - addr: addr, - logger: logger, - tableSemaphores: make(map[string]*semaphore.Weighted), - }, nil -} - -// GetAddr returns the database host:port or Unix socket address. -func (db *DB) GetAddr() string { - return db.addr -} - -// BuildDeleteStmt returns a DELETE statement for the given struct. -func (db *DB) BuildDeleteStmt(from interface{}) string { - return fmt.Sprintf( - `DELETE FROM "%s" WHERE id IN (?)`, - TableName(from), - ) -} - -// BuildInsertStmt returns an INSERT INTO statement for the given struct. -func (db *DB) BuildInsertStmt(into interface{}) (string, int) { - columns := db.columnMap.Columns(into) - - return fmt.Sprintf( - `INSERT INTO "%s" ("%s") VALUES (%s)`, - TableName(into), - strings.Join(columns, `", "`), - fmt.Sprintf(":%s", strings.Join(columns, ", :")), - ), len(columns) -} - -// BuildInsertIgnoreStmt returns an INSERT statement for the specified struct for -// which the database ignores rows that have already been inserted. -func (db *DB) BuildInsertIgnoreStmt(into interface{}) (string, int) { - table := TableName(into) - columns := db.columnMap.Columns(into) - var clause string - - switch db.DriverName() { - case MySQL: - // MySQL treats UPDATE id = id as a no-op. - clause = fmt.Sprintf(`ON DUPLICATE KEY UPDATE "%s" = "%s"`, columns[0], columns[0]) - case PostgreSQL: - var constraint string - if constrainter, ok := into.(PgsqlOnConflictConstrainter); ok { - constraint = constrainter.PgsqlOnConflictConstraint() - } else { - constraint = "pk_" + table - } - - clause = fmt.Sprintf("ON CONFLICT ON CONSTRAINT %s DO NOTHING", constraint) - } - - return fmt.Sprintf( - `INSERT INTO "%s" ("%s") VALUES (%s) %s`, - table, - strings.Join(columns, `", "`), - fmt.Sprintf(":%s", strings.Join(columns, ", :")), - clause, - ), len(columns) -} - -// BuildSelectStmt returns a SELECT query that creates the FROM part from the given table struct -// and the column list from the specified columns struct. -func (db *DB) BuildSelectStmt(table interface{}, columns interface{}) string { - q := fmt.Sprintf( - `SELECT "%s" FROM "%s"`, - strings.Join(db.columnMap.Columns(columns), `", "`), - TableName(table), - ) - - if scoper, ok := table.(Scoper); ok { - where, _ := db.BuildWhere(scoper.Scope()) - q += ` WHERE ` + where - } - - return q -} - -// BuildUpdateStmt returns an UPDATE statement for the given struct. -func (db *DB) BuildUpdateStmt(update interface{}) (string, int) { - columns := db.columnMap.Columns(update) - set := make([]string, 0, len(columns)) - - for _, col := range columns { - set = append(set, fmt.Sprintf(`"%s" = :%s`, col, col)) - } - - return fmt.Sprintf( - `UPDATE "%s" SET %s WHERE id = :id`, - TableName(update), - strings.Join(set, ", "), - ), len(columns) + 1 // +1 because of WHERE id = :id -} - -// BuildUpsertStmt returns an upsert statement for the given struct. -func (db *DB) BuildUpsertStmt(subject interface{}) (stmt string, placeholders int) { - insertColumns := db.columnMap.Columns(subject) - table := TableName(subject) - var updateColumns []string - - if upserter, ok := subject.(Upserter); ok { - updateColumns = db.columnMap.Columns(upserter.Upsert()) - } else { - updateColumns = insertColumns - } - - var clause, setFormat string - switch db.DriverName() { - case MySQL: - clause = "ON DUPLICATE KEY UPDATE" - setFormat = `"%[1]s" = VALUES("%[1]s")` - case PostgreSQL: - var constraint string - if constrainter, ok := subject.(PgsqlOnConflictConstrainter); ok { - constraint = constrainter.PgsqlOnConflictConstraint() - } else { - constraint = "pk_" + table - } - - clause = fmt.Sprintf("ON CONFLICT ON CONSTRAINT %s DO UPDATE SET", constraint) - setFormat = `"%[1]s" = EXCLUDED."%[1]s"` - } - - set := make([]string, 0, len(updateColumns)) - - for _, col := range updateColumns { - set = append(set, fmt.Sprintf(setFormat, col)) - } - - return fmt.Sprintf( - `INSERT INTO "%s" ("%s") VALUES (%s) %s %s`, - table, - strings.Join(insertColumns, `", "`), - fmt.Sprintf(":%s", strings.Join(insertColumns, ",:")), - clause, - strings.Join(set, ","), - ), len(insertColumns) -} - -// BuildWhere returns a WHERE clause with named placeholder conditions built from the specified struct -// combined with the AND operator. -func (db *DB) BuildWhere(subject interface{}) (string, int) { - columns := db.columnMap.Columns(subject) - where := make([]string, 0, len(columns)) - for _, col := range columns { - where = append(where, fmt.Sprintf(`"%s" = :%s`, col, col)) - } - - return strings.Join(where, ` AND `), len(columns) -} - -// OnSuccess is a callback for successful (bulk) DML operations. -type OnSuccess[T any] func(ctx context.Context, affectedRows []T) (err error) - -func OnSuccessIncrement[T any](counter *com.Counter) OnSuccess[T] { - return func(_ context.Context, rows []T) error { - counter.Add(uint64(len(rows))) - return nil - } -} - -func OnSuccessSendTo[T any](ch chan<- T) OnSuccess[T] { - return func(ctx context.Context, rows []T) error { - for _, row := range rows { - select { - case ch <- row: - case <-ctx.Done(): - return ctx.Err() - } - } - - return nil - } -} - -// BulkExec bulk executes queries with a single slice placeholder in the form of `IN (?)`. -// Takes in up to the number of arguments specified in count from the arg stream, -// derives and expands a query and executes it with this set of arguments until the arg stream has been processed. -// The derived queries are executed in a separate goroutine with a weighting of 1 -// and can be executed concurrently to the extent allowed by the semaphore passed in sem. -// Arguments for which the query ran successfully will be passed to onSuccess. -func (db *DB) BulkExec( - ctx context.Context, query string, count int, sem *semaphore.Weighted, arg <-chan any, onSuccess ...OnSuccess[any], -) error { - var counter com.Counter - defer db.Log(ctx, query, &counter).Stop() - - g, ctx := errgroup.WithContext(ctx) - // Use context from group. - bulk := com.Bulk(ctx, arg, count, com.NeverSplit[any]) - - g.Go(func() error { - g, ctx := errgroup.WithContext(ctx) - - for b := range bulk { - if err := sem.Acquire(ctx, 1); err != nil { - return errors.Wrap(err, "can't acquire semaphore") - } - - g.Go(func(b []interface{}) func() error { - return func() error { - defer sem.Release(1) - - return retry.WithBackoff( - ctx, - func(context.Context) error { - stmt, args, err := sqlx.In(query, b) - if err != nil { - return errors.Wrapf(err, "can't build placeholders for %q", query) - } - - stmt = db.Rebind(stmt) - _, err = db.ExecContext(ctx, stmt, args...) - if err != nil { - return CantPerformQuery(err, query) - } - - counter.Add(uint64(len(b))) - - for _, onSuccess := range onSuccess { - if err := onSuccess(ctx, b); err != nil { - return err - } - } - - return nil - }, - retry.Retryable, - backoff.NewExponentialWithJitter(1*time.Millisecond, 1*time.Second), - db.GetDefaultRetrySettings(), - ) - } - }(b)) - } - - return g.Wait() - }) - - return g.Wait() -} - -// NamedBulkExec bulk executes queries with named placeholders in a VALUES clause most likely -// in the format INSERT ... VALUES. Takes in up to the number of entities specified in count -// from the arg stream, derives and executes a new query with the VALUES clause expanded to -// this set of arguments, until the arg stream has been processed. -// The queries are executed in a separate goroutine with a weighting of 1 -// and can be executed concurrently to the extent allowed by the semaphore passed in sem. -// Entities for which the query ran successfully will be passed to onSuccess. -func (db *DB) NamedBulkExec( - ctx context.Context, query string, count int, sem *semaphore.Weighted, arg <-chan Entity, - splitPolicyFactory com.BulkChunkSplitPolicyFactory[Entity], onSuccess ...OnSuccess[Entity], -) error { - var counter com.Counter - defer db.Log(ctx, query, &counter).Stop() - - g, ctx := errgroup.WithContext(ctx) - bulk := com.Bulk(ctx, arg, count, splitPolicyFactory) - - g.Go(func() error { - for { - select { - case b, ok := <-bulk: - if !ok { - return nil - } - - if err := sem.Acquire(ctx, 1); err != nil { - return errors.Wrap(err, "can't acquire semaphore") - } - - g.Go(func(b []Entity) func() error { - return func() error { - defer sem.Release(1) - - return retry.WithBackoff( - ctx, - func(ctx context.Context) error { - _, err := db.NamedExecContext(ctx, query, b) - if err != nil { - return CantPerformQuery(err, query) - } - - counter.Add(uint64(len(b))) - - for _, onSuccess := range onSuccess { - if err := onSuccess(ctx, b); err != nil { - return err - } - } - - return nil - }, - retry.Retryable, - backoff.NewExponentialWithJitter(1*time.Millisecond, 1*time.Second), - db.GetDefaultRetrySettings(), - ) - } - }(b)) - case <-ctx.Done(): - return ctx.Err() - } - } - }) - - return g.Wait() -} - -// NamedBulkExecTx bulk executes queries with named placeholders in separate transactions. -// Takes in up to the number of entities specified in count from the arg stream and -// executes a new transaction that runs a new query for each entity in this set of arguments, -// until the arg stream has been processed. -// The transactions are executed in a separate goroutine with a weighting of 1 -// and can be executed concurrently to the extent allowed by the semaphore passed in sem. -func (db *DB) NamedBulkExecTx( - ctx context.Context, query string, count int, sem *semaphore.Weighted, arg <-chan Entity, -) error { - var counter com.Counter - defer db.Log(ctx, query, &counter).Stop() - - g, ctx := errgroup.WithContext(ctx) - bulk := com.Bulk(ctx, arg, count, com.NeverSplit[Entity]) - - g.Go(func() error { - for { - select { - case b, ok := <-bulk: - if !ok { - return nil - } - - if err := sem.Acquire(ctx, 1); err != nil { - return errors.Wrap(err, "can't acquire semaphore") - } - - g.Go(func(b []Entity) func() error { - return func() error { - defer sem.Release(1) - - return retry.WithBackoff( - ctx, - func(ctx context.Context) error { - tx, err := db.BeginTxx(ctx, nil) - if err != nil { - return errors.Wrap(err, "can't start transaction") - } - - stmt, err := tx.PrepareNamedContext(ctx, query) - if err != nil { - return errors.Wrap(err, "can't prepare named statement with context in transaction") - } - - for _, arg := range b { - if _, err := stmt.ExecContext(ctx, arg); err != nil { - return errors.Wrap(err, "can't execute statement in transaction") - } - } - - if err := tx.Commit(); err != nil { - return errors.Wrap(err, "can't commit transaction") - } - - counter.Add(uint64(len(b))) - - return nil - }, - retry.Retryable, - backoff.NewExponentialWithJitter(1*time.Millisecond, 1*time.Second), - db.GetDefaultRetrySettings(), - ) - } - }(b)) - case <-ctx.Done(): - return ctx.Err() - } - } - }) - - return g.Wait() -} - -// BatchSizeByPlaceholders returns how often the specified number of placeholders fits -// into Options.MaxPlaceholdersPerStatement, but at least 1. -func (db *DB) BatchSizeByPlaceholders(n int) int { - s := db.Options.MaxPlaceholdersPerStatement / n - if s > 0 { - return s - } - - return 1 -} - -// YieldAll executes the query with the supplied scope, -// scans each resulting row into an entity returned by the factory function, -// and streams them into a returned channel. -func (db *DB) YieldAll(ctx context.Context, factoryFunc EntityFactoryFunc, query string, scope interface{}) (<-chan Entity, <-chan error) { - entities := make(chan Entity, 1) - g, ctx := errgroup.WithContext(ctx) - - g.Go(func() error { - var counter com.Counter - defer db.Log(ctx, query, &counter).Stop() - defer close(entities) - - rows, err := db.NamedQueryContext(ctx, query, scope) - if err != nil { - return CantPerformQuery(err, query) - } - defer rows.Close() - - for rows.Next() { - e := factoryFunc() - - if err := rows.StructScan(e); err != nil { - return errors.Wrapf(err, "can't store query result into a %T: %s", e, query) - } - - select { - case entities <- e: - counter.Inc() - case <-ctx.Done(): - return ctx.Err() - } - } - - return nil - }) - - return entities, com.WaitAsync(g) -} - -// CreateStreamed bulk creates the specified entities via NamedBulkExec. -// The insert statement is created using BuildInsertStmt with the first entity from the entities stream. -// Bulk size is controlled via Options.MaxPlaceholdersPerStatement and -// concurrency is controlled via Options.MaxConnectionsPerTable. -// Entities for which the query ran successfully will be passed to onSuccess. -func (db *DB) CreateStreamed( - ctx context.Context, entities <-chan Entity, onSuccess ...OnSuccess[Entity], -) error { - first, forward, err := com.CopyFirst(ctx, entities) - if err != nil { - return errors.Wrap(err, "can't copy first entity") - } - - sem := db.GetSemaphoreForTable(TableName(first)) - stmt, placeholders := db.BuildInsertStmt(first) - - return db.NamedBulkExec( - ctx, stmt, db.BatchSizeByPlaceholders(placeholders), sem, - forward, com.NeverSplit[Entity], onSuccess..., - ) -} - -// CreateIgnoreStreamed bulk creates the specified entities via NamedBulkExec. -// The insert statement is created using BuildInsertIgnoreStmt with the first entity from the entities stream. -// Bulk size is controlled via Options.MaxPlaceholdersPerStatement and -// concurrency is controlled via Options.MaxConnectionsPerTable. -// Entities for which the query ran successfully will be passed to onSuccess. -func (db *DB) CreateIgnoreStreamed( - ctx context.Context, entities <-chan Entity, onSuccess ...OnSuccess[Entity], -) error { - first, forward, err := com.CopyFirst(ctx, entities) - if err != nil { - return errors.Wrap(err, "can't copy first entity") - } - - sem := db.GetSemaphoreForTable(TableName(first)) - stmt, placeholders := db.BuildInsertIgnoreStmt(first) - - return db.NamedBulkExec( - ctx, stmt, db.BatchSizeByPlaceholders(placeholders), sem, - forward, SplitOnDupId[Entity], onSuccess..., - ) -} - -// UpsertStreamed bulk upserts the specified entities via NamedBulkExec. -// The upsert statement is created using BuildUpsertStmt with the first entity from the entities stream. -// Bulk size is controlled via Options.MaxPlaceholdersPerStatement and -// concurrency is controlled via Options.MaxConnectionsPerTable. -// Entities for which the query ran successfully will be passed to onSuccess. -func (db *DB) UpsertStreamed( - ctx context.Context, entities <-chan Entity, onSuccess ...OnSuccess[Entity], -) error { - first, forward, err := com.CopyFirst(ctx, entities) - if err != nil { - return errors.Wrap(err, "can't copy first entity") - } - - sem := db.GetSemaphoreForTable(TableName(first)) - stmt, placeholders := db.BuildUpsertStmt(first) - - return db.NamedBulkExec( - ctx, stmt, db.BatchSizeByPlaceholders(placeholders), sem, - forward, SplitOnDupId[Entity], onSuccess..., - ) -} - -// UpdateStreamed bulk updates the specified entities via NamedBulkExecTx. -// The update statement is created using BuildUpdateStmt with the first entity from the entities stream. -// Bulk size is controlled via Options.MaxRowsPerTransaction and -// concurrency is controlled via Options.MaxConnectionsPerTable. -func (db *DB) UpdateStreamed(ctx context.Context, entities <-chan Entity) error { - first, forward, err := com.CopyFirst(ctx, entities) - if err != nil { - return errors.Wrap(err, "can't copy first entity") - } - sem := db.GetSemaphoreForTable(TableName(first)) - stmt, _ := db.BuildUpdateStmt(first) - - return db.NamedBulkExecTx(ctx, stmt, db.Options.MaxRowsPerTransaction, sem, forward) -} - -// DeleteStreamed bulk deletes the specified ids via BulkExec. -// The delete statement is created using BuildDeleteStmt with the passed entityType. -// Bulk size is controlled via Options.MaxPlaceholdersPerStatement and -// concurrency is controlled via Options.MaxConnectionsPerTable. -// IDs for which the query ran successfully will be passed to onSuccess. -func (db *DB) DeleteStreamed( - ctx context.Context, entityType Entity, ids <-chan interface{}, onSuccess ...OnSuccess[any], -) error { - sem := db.GetSemaphoreForTable(TableName(entityType)) - return db.BulkExec( - ctx, db.BuildDeleteStmt(entityType), db.Options.MaxPlaceholdersPerStatement, sem, ids, onSuccess..., - ) -} - -// Delete creates a channel from the specified ids and -// bulk deletes them by passing the channel along with the entityType to DeleteStreamed. -// IDs for which the query ran successfully will be passed to onSuccess. -func (db *DB) Delete( - ctx context.Context, entityType Entity, ids []interface{}, onSuccess ...OnSuccess[any], -) error { - idsCh := make(chan interface{}, len(ids)) - for _, id := range ids { - idsCh <- id - } - close(idsCh) - - return db.DeleteStreamed(ctx, entityType, idsCh, onSuccess...) -} - -func (db *DB) GetSemaphoreForTable(table string) *semaphore.Weighted { - db.tableSemaphoresMu.Lock() - defer db.tableSemaphoresMu.Unlock() - - if sem, ok := db.tableSemaphores[table]; ok { - return sem - } else { - sem = semaphore.NewWeighted(int64(db.Options.MaxConnectionsPerTable)) - db.tableSemaphores[table] = sem - return sem - } -} - -func (db *DB) GetDefaultRetrySettings() retry.Settings { - return retry.Settings{ - Timeout: retry.DefaultTimeout, - OnRetryableError: func(_ time.Duration, _ uint64, err, lastErr error) { - if lastErr == nil || err.Error() != lastErr.Error() { - db.logger.Warnw("Can't execute query. Retrying", zap.Error(err)) - } - }, - OnSuccess: func(elapsed time.Duration, attempt uint64, lastErr error) { - if attempt > 1 { - db.logger.Infow("Query retried successfully after error", - zap.Duration("after", elapsed), - zap.Uint64("attempts", attempt), - zap.NamedError("recovered_error", lastErr)) - } - }, - } -} - -func (db *DB) Log(ctx context.Context, query string, counter *com.Counter) periodic.Stopper { - return periodic.Start(ctx, db.logger.Interval(), func(tick periodic.Tick) { - if count := counter.Reset(); count > 0 { - db.logger.Debugf("Executed %q with %d rows", query, count) - } - }, periodic.OnStop(func(tick periodic.Tick) { - db.logger.Debugf("Finished executing %q with %d rows in %s", query, counter.Total(), tick.Elapsed) - })) -} diff --git a/pkg/database/driver.go b/pkg/database/driver.go deleted file mode 100644 index 5aa03aeac..000000000 --- a/pkg/database/driver.go +++ /dev/null @@ -1,100 +0,0 @@ -package database - -import ( - "context" - "database/sql/driver" - "github.com/icinga/icinga-go-library/backoff" - "github.com/icinga/icinga-go-library/logging" - "github.com/icinga/icinga-go-library/retry" - "github.com/pkg/errors" - "go.uber.org/zap" - "time" -) - -// Driver names as automatically registered in the database/sql package by themselves. -const ( - MySQL string = "mysql" - PostgreSQL string = "postgres" -) - -// OnInitConnFunc can be used to execute post Connect() arbitrary actions. -// It will be called after successfully initiated a new connection using the connector's Connect method. -type OnInitConnFunc func(context.Context, driver.Conn) error - -// RetryConnectorCallbacks specifies callbacks that are executed upon certain events. -type RetryConnectorCallbacks struct { - OnInitConn OnInitConnFunc - OnRetryableError retry.OnRetryableErrorFunc - OnSuccess retry.OnSuccessFunc -} - -// RetryConnector wraps driver.Connector with retry logic. -type RetryConnector struct { - driver.Connector - - logger *logging.Logger - - callbacks RetryConnectorCallbacks -} - -// NewConnector creates a fully initialized RetryConnector from the given args. -func NewConnector(c driver.Connector, logger *logging.Logger, callbacks RetryConnectorCallbacks) *RetryConnector { - return &RetryConnector{Connector: c, logger: logger, callbacks: callbacks} -} - -// Connect implements part of the driver.Connector interface. -func (c RetryConnector) Connect(ctx context.Context) (driver.Conn, error) { - var conn driver.Conn - err := errors.Wrap(retry.WithBackoff( - ctx, - func(ctx context.Context) (err error) { - conn, err = c.Connector.Connect(ctx) - if err == nil && c.callbacks.OnInitConn != nil { - if err = c.callbacks.OnInitConn(ctx, conn); err != nil { - // We're going to retry this, so just don't bother whether Close() fails! - _ = conn.Close() - } - } - - return - }, - retry.Retryable, - backoff.NewExponentialWithJitter(128*time.Millisecond, 1*time.Minute), - retry.Settings{ - Timeout: retry.DefaultTimeout, - OnRetryableError: func(elapsed time.Duration, attempt uint64, err, lastErr error) { - if c.callbacks.OnRetryableError != nil { - c.callbacks.OnRetryableError(elapsed, attempt, err, lastErr) - } - - if lastErr == nil || err.Error() != lastErr.Error() { - c.logger.Warnw("Can't connect to database. Retrying", zap.Error(err)) - } - }, - OnSuccess: func(elapsed time.Duration, attempt uint64, lastErr error) { - if c.callbacks.OnSuccess != nil { - c.callbacks.OnSuccess(elapsed, attempt, lastErr) - } - - if attempt > 1 { - c.logger.Infow("Reconnected to database", - zap.Duration("after", elapsed), zap.Uint64("attempts", attempt)) - } - }, - }, - ), "can't connect to database") - return conn, err -} - -// Driver implements part of the driver.Connector interface. -func (c RetryConnector) Driver() driver.Driver { - return c.Connector.Driver() -} - -// MysqlFuncLogger is an adapter that allows ordinary functions to be used as a logger for mysql.SetLogger. -type MysqlFuncLogger func(v ...interface{}) - -// Print implements the mysql.Logger interface. -func (log MysqlFuncLogger) Print(v ...interface{}) { - log(v) -} diff --git a/pkg/database/utils.go b/pkg/database/utils.go deleted file mode 100644 index 2ae372cec..000000000 --- a/pkg/database/utils.go +++ /dev/null @@ -1,81 +0,0 @@ -package database - -import ( - "context" - "database/sql/driver" - "github.com/go-sql-driver/mysql" - "github.com/icinga/icinga-go-library/com" - "github.com/icinga/icinga-go-library/strcase" - "github.com/icinga/icinga-go-library/types" - "github.com/pkg/errors" -) - -// CantPerformQuery wraps the given error with the specified query that cannot be executed. -func CantPerformQuery(err error, q string) error { - return errors.Wrapf(err, "can't perform %q", q) -} - -// TableName returns the table of t. -func TableName(t interface{}) string { - if tn, ok := t.(TableNamer); ok { - return tn.TableName() - } else { - return strcase.Snake(types.Name(t)) - } -} - -// SplitOnDupId returns a state machine which tracks the inputs' IDs. -// Once an already seen input arrives, it demands splitting. -func SplitOnDupId[T IDer]() com.BulkChunkSplitPolicy[T] { - seenIds := map[string]struct{}{} - - return func(ider T) bool { - id := ider.ID().String() - - _, ok := seenIds[id] - if ok { - seenIds = map[string]struct{}{id: {}} - } else { - seenIds[id] = struct{}{} - } - - return ok - } -} - -// setGaleraOpts sets the "wsrep_sync_wait" variable for each session ensures that causality checks are performed -// before execution and that each statement is executed on a fully synchronized node. Doing so prevents foreign key -// violation when inserting into dependent tables on different MariaDB/MySQL nodes. When using MySQL single nodes, -// the "SET SESSION" command will fail with "Unknown system variable (1193)" and will therefore be silently dropped. -// -// https://mariadb.com/kb/en/galera-cluster-system-variables/#wsrep_sync_wait -func setGaleraOpts(ctx context.Context, conn driver.Conn, wsrepSyncWait int64) error { - const galeraOpts = "SET SESSION wsrep_sync_wait=?" - - stmt, err := conn.(driver.ConnPrepareContext).PrepareContext(ctx, galeraOpts) - if err != nil { - if errors.Is(err, &mysql.MySQLError{Number: 1193}) { // Unknown system variable - return nil - } - - return errors.Wrap(err, "cannot prepare "+galeraOpts) - } - // This is just for an unexpected exit and any returned error can safely be ignored and in case - // of the normal function exit, the stmt is closed manually, and its error is handled gracefully. - defer func() { _ = stmt.Close() }() - - _, err = stmt.(driver.StmtExecContext).ExecContext(ctx, []driver.NamedValue{{Value: wsrepSyncWait}}) - if err != nil { - return errors.Wrap(err, "cannot execute "+galeraOpts) - } - - if err = stmt.Close(); err != nil { - return errors.Wrap(err, "cannot close prepared statement "+galeraOpts) - } - - return nil -} - -var ( - _ com.BulkChunkSplitPolicyFactory[Entity] = SplitOnDupId[Entity] -) diff --git a/pkg/flatten/flatten.go b/pkg/flatten/flatten.go deleted file mode 100644 index 6d3cee68d..000000000 --- a/pkg/flatten/flatten.go +++ /dev/null @@ -1,46 +0,0 @@ -package flatten - -import ( - "fmt" - "github.com/icinga/icinga-go-library/types" - "strconv" -) - -// Flatten creates flat, one-dimensional maps from arbitrarily nested values, e.g. JSON. -func Flatten(value interface{}, prefix string) map[string]types.String { - var flatten func(string, interface{}) - flattened := make(map[string]types.String) - - flatten = func(key string, value interface{}) { - switch value := value.(type) { - case map[string]interface{}: - if len(value) == 0 { - flattened[key] = types.String{} - break - } - - for k, v := range value { - flatten(key+"."+k, v) - } - case []interface{}: - if len(value) == 0 { - flattened[key] = types.String{} - break - } - - for i, v := range value { - flatten(key+"["+strconv.Itoa(i)+"]", v) - } - case nil: - flattened[key] = types.MakeString("null") - case float64: - flattened[key] = types.MakeString(strconv.FormatFloat(value, 'f', -1, 64)) - default: - flattened[key] = types.MakeString(fmt.Sprintf("%v", value)) - } - } - - flatten(prefix, value) - - return flattened -} diff --git a/pkg/flatten/flatten_test.go b/pkg/flatten/flatten_test.go deleted file mode 100644 index 99bf67a75..000000000 --- a/pkg/flatten/flatten_test.go +++ /dev/null @@ -1,45 +0,0 @@ -package flatten - -import ( - "github.com/icinga/icinga-go-library/types" - "github.com/stretchr/testify/assert" - "testing" -) - -func TestFlatten(t *testing.T) { - for _, st := range []struct { - name string - prefix string - value any - output map[string]types.String - }{ - {"nil", "a", nil, map[string]types.String{"a": types.MakeString("null")}}, - {"bool", "b", true, map[string]types.String{"b": types.MakeString("true")}}, - {"int", "c", 42, map[string]types.String{"c": types.MakeString("42")}}, - {"float", "d", 77.7, map[string]types.String{"d": types.MakeString("77.7")}}, - {"large_float", "e", 1e23, map[string]types.String{"e": types.MakeString("100000000000000000000000")}}, - {"string", "f", "\x00", map[string]types.String{"f": types.MakeString("\x00")}}, - {"nil_slice", "g", []any(nil), map[string]types.String{"g": {}}}, - {"empty_slice", "h", []any{}, map[string]types.String{"h": {}}}, - {"slice", "i", []any{nil}, map[string]types.String{"i[0]": types.MakeString("null")}}, - {"nil_map", "j", map[string]any(nil), map[string]types.String{"j": {}}}, - {"empty_map", "k", map[string]any{}, map[string]types.String{"k": {}}}, - {"map", "l", map[string]any{" ": nil}, map[string]types.String{"l. ": types.MakeString("null")}}, - {"map_with_slice", "m", map[string]any{"\t": []any{"ä", "ö", "ü"}, "ß": "s"}, map[string]types.String{ - "m.\t[0]": types.MakeString("ä"), - "m.\t[1]": types.MakeString("ö"), - "m.\t[2]": types.MakeString("ü"), - "m.ß": types.MakeString("s"), - }}, - {"slice_with_map", "n", []any{map[string]any{"ä": "a", "ö": "o", "ü": "u"}, "ß"}, map[string]types.String{ - "n[0].ä": types.MakeString("a"), - "n[0].ö": types.MakeString("o"), - "n[0].ü": types.MakeString("u"), - "n[1]": types.MakeString("ß"), - }}, - } { - t.Run(st.name, func(t *testing.T) { - assert.Equal(t, st.output, Flatten(st.value, st.prefix)) - }) - } -} diff --git a/pkg/logging/config.go b/pkg/logging/config.go deleted file mode 100644 index 00eb14060..000000000 --- a/pkg/logging/config.go +++ /dev/null @@ -1,60 +0,0 @@ -package logging - -import ( - "fmt" - "github.com/pkg/errors" - "go.uber.org/zap/zapcore" - "os" - "time" -) - -// Options define child loggers with their desired log level. -type Options map[string]zapcore.Level - -// Config defines Logger configuration. -type Config struct { - // zapcore.Level at 0 is for info level. - Level zapcore.Level `yaml:"level" default:"0"` - Output string `yaml:"output"` - // Interval for periodic logging. - Interval time.Duration `yaml:"interval" default:"20s"` - - Options `yaml:"options"` -} - -// Validate checks constraints in the configuration and returns an error if they are violated. -// Also configures the log output if it is not configured: -// systemd-journald is used when Icinga DB is running under systemd, otherwise stderr. -func (l *Config) Validate() error { - if l.Interval <= 0 { - return errors.New("periodic logging interval must be positive") - } - - if l.Output == "" { - if _, ok := os.LookupEnv("NOTIFY_SOCKET"); ok { - // When started by systemd, NOTIFY_SOCKET is set by systemd for Type=notify supervised services, - // which is the default setting for the Icinga DB service. - // This assumes that Icinga DB is running under systemd, so set output to systemd-journald. - l.Output = JOURNAL - } else { - // Otherwise set it to console, i.e. write log messages to stderr. - l.Output = CONSOLE - } - } - - // To be on the safe side, always call AssertOutput. - return AssertOutput(l.Output) -} - -// AssertOutput returns an error if output is not a valid logger output. -func AssertOutput(o string) error { - if o == CONSOLE || o == JOURNAL { - return nil - } - - return invalidOutput(o) -} - -func invalidOutput(o string) error { - return fmt.Errorf("%s is not a valid logger output. Must be either %q or %q", o, CONSOLE, JOURNAL) -} diff --git a/pkg/logging/journald_core.go b/pkg/logging/journald_core.go deleted file mode 100644 index d1943a539..000000000 --- a/pkg/logging/journald_core.go +++ /dev/null @@ -1,84 +0,0 @@ -package logging - -import ( - "github.com/icinga/icinga-go-library/strcase" - "github.com/pkg/errors" - "github.com/ssgreg/journald" - "go.uber.org/zap/zapcore" - "strings" -) - -// priorities maps zapcore.Level to journal.Priority. -var priorities = map[zapcore.Level]journald.Priority{ - zapcore.DebugLevel: journald.PriorityDebug, - zapcore.InfoLevel: journald.PriorityInfo, - zapcore.WarnLevel: journald.PriorityWarning, - zapcore.ErrorLevel: journald.PriorityErr, - zapcore.FatalLevel: journald.PriorityCrit, - zapcore.PanicLevel: journald.PriorityCrit, - zapcore.DPanicLevel: journald.PriorityCrit, -} - -// NewJournaldCore returns a zapcore.Core that sends log entries to systemd-journald and -// uses the given identifier as a prefix for structured logging context that is sent as journal fields. -func NewJournaldCore(identifier string, enab zapcore.LevelEnabler) zapcore.Core { - return &journaldCore{ - LevelEnabler: enab, - identifier: identifier, - identifierU: strings.ToUpper(identifier), - } -} - -type journaldCore struct { - zapcore.LevelEnabler - context []zapcore.Field - identifier string - identifierU string -} - -func (c *journaldCore) Check(ent zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore.CheckedEntry { - if c.Enabled(ent.Level) { - return ce.AddCore(ent, c) - } - - return ce -} - -func (c *journaldCore) Sync() error { - return nil -} - -func (c *journaldCore) With(fields []zapcore.Field) zapcore.Core { - cc := *c - cc.context = append(cc.context[:len(cc.context):len(cc.context)], fields...) - - return &cc -} - -func (c *journaldCore) Write(ent zapcore.Entry, fields []zapcore.Field) error { - pri, ok := priorities[ent.Level] - if !ok { - return errors.Errorf("unknown log level %q", ent.Level) - } - - enc := zapcore.NewMapObjectEncoder() - c.addFields(enc, fields) - c.addFields(enc, c.context) - enc.Fields["SYSLOG_IDENTIFIER"] = c.identifier - - message := ent.Message - if ent.LoggerName != c.identifier { - message = ent.LoggerName + ": " + message - } - - return journald.Send(message, pri, enc.Fields) -} - -func (c *journaldCore) addFields(enc zapcore.ObjectEncoder, fields []zapcore.Field) { - for _, field := range fields { - field.Key = c.identifierU + - "_" + - strcase.ScreamingSnake(field.Key) - field.AddTo(enc) - } -} diff --git a/pkg/logging/logger.go b/pkg/logging/logger.go deleted file mode 100644 index 490445e17..000000000 --- a/pkg/logging/logger.go +++ /dev/null @@ -1,26 +0,0 @@ -package logging - -import ( - "go.uber.org/zap" - "time" -) - -// Logger wraps zap.SugaredLogger and -// allows to get the interval for periodic logging. -type Logger struct { - *zap.SugaredLogger - interval time.Duration -} - -// NewLogger returns a new Logger. -func NewLogger(base *zap.SugaredLogger, interval time.Duration) *Logger { - return &Logger{ - SugaredLogger: base, - interval: interval, - } -} - -// Interval returns the interval for periodic logging. -func (l *Logger) Interval() time.Duration { - return l.interval -} diff --git a/pkg/logging/logging.go b/pkg/logging/logging.go deleted file mode 100644 index f355da117..000000000 --- a/pkg/logging/logging.go +++ /dev/null @@ -1,119 +0,0 @@ -package logging - -import ( - "go.uber.org/zap" - "go.uber.org/zap/zapcore" - "os" - "sync" - "time" -) - -const ( - CONSOLE = "console" - JOURNAL = "systemd-journald" -) - -// defaultEncConfig defines the default zapcore.EncoderConfig for the logging package. -var defaultEncConfig = zapcore.EncoderConfig{ - TimeKey: "ts", - LevelKey: "level", - NameKey: "logger", - CallerKey: "caller", - MessageKey: "msg", - StacktraceKey: "stacktrace", - LineEnding: zapcore.DefaultLineEnding, - EncodeLevel: zapcore.CapitalLevelEncoder, - EncodeTime: zapcore.ISO8601TimeEncoder, - EncodeDuration: zapcore.StringDurationEncoder, - EncodeCaller: zapcore.ShortCallerEncoder, -} - -// Logging implements access to a default logger and named child loggers. -// Log levels can be configured per named child via Options which, if not configured, -// fall back on a default log level. -// Logs either to the console or to systemd-journald. -type Logging struct { - logger *Logger - output string - verbosity zap.AtomicLevel - interval time.Duration - - // coreFactory creates zapcore.Core based on the log level and the log output. - coreFactory func(zap.AtomicLevel) zapcore.Core - - mu sync.Mutex - loggers map[string]*Logger - - options Options -} - -// NewLogging takes the name and log level for the default logger, -// output where log messages are written to, -// options having log levels for named child loggers -// and returns a new Logging. -func NewLogging(name string, level zapcore.Level, output string, options Options, interval time.Duration) (*Logging, error) { - verbosity := zap.NewAtomicLevelAt(level) - - var coreFactory func(zap.AtomicLevel) zapcore.Core - switch output { - case CONSOLE: - enc := zapcore.NewConsoleEncoder(defaultEncConfig) - ws := zapcore.Lock(os.Stderr) - coreFactory = func(verbosity zap.AtomicLevel) zapcore.Core { - return zapcore.NewCore(enc, ws, verbosity) - } - case JOURNAL: - coreFactory = func(verbosity zap.AtomicLevel) zapcore.Core { - return NewJournaldCore(name, verbosity) - } - default: - return nil, invalidOutput(output) - } - - logger := NewLogger(zap.New(coreFactory(verbosity)).Named(name).Sugar(), interval) - - return &Logging{ - logger: logger, - output: output, - verbosity: verbosity, - interval: interval, - coreFactory: coreFactory, - loggers: make(map[string]*Logger), - options: options, - }, - nil -} - -// NewLoggingFromConfig returns a new Logging from Config. -func NewLoggingFromConfig(name string, c Config) (*Logging, error) { - return NewLogging(name, c.Level, c.Output, c.Options, c.Interval) -} - -// GetChildLogger returns a named child logger. -// Log levels for named child loggers are obtained from the logging options and, if not found, -// set to the default log level. -func (l *Logging) GetChildLogger(name string) *Logger { - l.mu.Lock() - defer l.mu.Unlock() - - if logger, ok := l.loggers[name]; ok { - return logger - } - - var verbosity zap.AtomicLevel - if level, found := l.options[name]; found { - verbosity = zap.NewAtomicLevelAt(level) - } else { - verbosity = l.verbosity - } - - logger := NewLogger(zap.New(l.coreFactory(verbosity)).Named(name).Sugar(), l.interval) - l.loggers[name] = logger - - return logger -} - -// GetLogger returns the default logger. -func (l *Logging) GetLogger() *Logger { - return l.logger -} diff --git a/pkg/objectpacker/objectpacker.go b/pkg/objectpacker/objectpacker.go deleted file mode 100644 index 015274599..000000000 --- a/pkg/objectpacker/objectpacker.go +++ /dev/null @@ -1,213 +0,0 @@ -package objectpacker - -import ( - "bytes" - "encoding/binary" - "fmt" - "github.com/pkg/errors" - "io" - "reflect" - "sort" -) - -// MustPackSlice calls PackAny using items and panics if there was an error. -func MustPackSlice(items ...interface{}) []byte { - var buf bytes.Buffer - - if err := PackAny(items, &buf); err != nil { - panic(err) - } - - return buf.Bytes() -} - -// PackAny packs any JSON-encodable value (ex. structs, also ignores interfaces like encoding.TextMarshaler) -// to a BSON-similar format suitable for consistent hashing. Spec: -// -// PackAny(nil) => 0x0 -// PackAny(false) => 0x1 -// PackAny(true) => 0x2 -// PackAny(float64(42)) => 0x3 ieee754_binary64_bigendian(42) -// PackAny("exämple") => 0x4 uint64_bigendian(len([]byte("exämple"))) []byte("exämple") -// PackAny([]uint8{0x42}) => 0x4 uint64_bigendian(len([]uint8{0x42})) []uint8{0x42} -// PackAny([1]uint8{0x42}) => 0x4 uint64_bigendian(len([1]uint8{0x42})) [1]uint8{0x42} -// PackAny([]T{x,y}) => 0x5 uint64_bigendian(len([]T{x,y})) PackAny(x) PackAny(y) -// PackAny(map[K]V{x:y}) => 0x6 uint64_bigendian(len(map[K]V{x:y})) len(map_key(x)) map_key(x) PackAny(y) -// PackAny((*T)(nil)) => 0x0 -// PackAny((*T)(0x42)) => PackAny(*(*T)(0x42)) -// PackAny(x) => panic() -// -// map_key([1]uint8{0x42}) => [1]uint8{0x42} -// map_key(x) => []byte(fmt.Sprint(x)) -func PackAny(in interface{}, out io.Writer) error { - return errors.Wrapf(packValue(reflect.ValueOf(in), out), "can't pack %#v", in) -} - -var tByte = reflect.TypeOf(byte(0)) -var tBytes = reflect.TypeOf([]uint8(nil)) - -// packValue does the actual job of packAny and just exists for recursion w/o unnecessary reflect.ValueOf calls. -func packValue(in reflect.Value, out io.Writer) error { - switch kind := in.Kind(); kind { - case reflect.Invalid: // nil - _, err := out.Write([]byte{0}) - return err - case reflect.Bool: - if in.Bool() { - _, err := out.Write([]byte{2}) - return err - } else { - _, err := out.Write([]byte{1}) - return err - } - case reflect.Float64: - if _, err := out.Write([]byte{3}); err != nil { - return err - } - - return binary.Write(out, binary.BigEndian, in.Float()) - case reflect.Array, reflect.Slice: - if typ := in.Type(); typ.Elem() == tByte { - if kind == reflect.Array { - if !in.CanAddr() { - vNewElem := reflect.New(typ).Elem() - vNewElem.Set(in) - in = vNewElem - } - - in = in.Slice(0, in.Len()) - } - - // Pack []byte as string, not array of numbers. - return packString(in.Convert(tBytes). // Support types.Binary - Interface().([]uint8), out) - } - - if _, err := out.Write([]byte{5}); err != nil { - return err - } - - l := in.Len() - if err := binary.Write(out, binary.BigEndian, uint64(l)); err != nil { - return err - } - - for i := 0; i < l; i++ { - if err := packValue(in.Index(i), out); err != nil { - return err - } - } - - // If there aren't any values to pack, ... - if l < 1 { - // ... create one and pack it - panics on disallowed type. - _ = packValue(reflect.Zero(in.Type().Elem()), io.Discard) - } - - return nil - case reflect.Interface: - return packValue(in.Elem(), out) - case reflect.Map: - type kv struct { - key []byte - value reflect.Value - } - - if _, err := out.Write([]byte{6}); err != nil { - return err - } - - l := in.Len() - if err := binary.Write(out, binary.BigEndian, uint64(l)); err != nil { - return err - } - - sorted := make([]kv, 0, l) - - { - iter := in.MapRange() - for iter.Next() { - var packedKey []byte - if key := iter.Key(); key.Kind() == reflect.Array { - if typ := key.Type(); typ.Elem() == tByte { - if !key.CanAddr() { - vNewElem := reflect.New(typ).Elem() - vNewElem.Set(key) - key = vNewElem - } - - packedKey = key.Slice(0, key.Len()).Interface().([]byte) - } else { - // Not just stringify the key (below), but also pack it (here) - panics on disallowed type. - _ = packValue(iter.Key(), io.Discard) - - packedKey = []byte(fmt.Sprint(key.Interface())) - } - } else { - // Not just stringify the key (below), but also pack it (here) - panics on disallowed type. - _ = packValue(iter.Key(), io.Discard) - - packedKey = []byte(fmt.Sprint(key.Interface())) - } - - sorted = append(sorted, kv{packedKey, iter.Value()}) - } - } - - sort.Slice(sorted, func(i, j int) bool { return bytes.Compare(sorted[i].key, sorted[j].key) < 0 }) - - for _, kv := range sorted { - if err := binary.Write(out, binary.BigEndian, uint64(len(kv.key))); err != nil { - return err - } - - if _, err := out.Write(kv.key); err != nil { - return err - } - - if err := packValue(kv.value, out); err != nil { - return err - } - } - - // If there aren't any key-value pairs to pack, ... - if l < 1 { - typ := in.Type() - - // ... create one and pack it - panics on disallowed type. - _ = packValue(reflect.Zero(typ.Key()), io.Discard) - _ = packValue(reflect.Zero(typ.Elem()), io.Discard) - } - - return nil - case reflect.Ptr: - if in.IsNil() { - err := packValue(reflect.Value{}, out) - - // Create a fictive referenced value and pack it - panics on disallowed type. - _ = packValue(reflect.Zero(in.Type().Elem()), io.Discard) - - return err - } else { - return packValue(in.Elem(), out) - } - case reflect.String: - return packString([]byte(in.String()), out) - default: - panic("bad type: " + in.Kind().String()) - } -} - -// packString deduplicates string packing of multiple locations in packValue. -func packString(in []byte, out io.Writer) error { - if _, err := out.Write([]byte{4}); err != nil { - return err - } - - if err := binary.Write(out, binary.BigEndian, uint64(len(in))); err != nil { - return err - } - - _, err := out.Write(in) - return err -} diff --git a/pkg/objectpacker/objectpacker_test.go b/pkg/objectpacker/objectpacker_test.go deleted file mode 100644 index 27389dccc..000000000 --- a/pkg/objectpacker/objectpacker_test.go +++ /dev/null @@ -1,195 +0,0 @@ -package objectpacker - -import ( - "bytes" - "github.com/icinga/icinga-go-library/types" - "github.com/pkg/errors" - "io" - "testing" -) - -// limitedWriter allows writing a specific amount of data. -type limitedWriter struct { - // limit specifies how many bytes to allow to write. - limit int -} - -var _ io.Writer = (*limitedWriter)(nil) - -// Write returns io.EOF once lw.limit is exceeded, nil otherwise. -func (lw *limitedWriter) Write(p []byte) (n int, err error) { - if len(p) <= lw.limit { - lw.limit -= len(p) - return len(p), nil - } - - n = lw.limit - err = io.EOF - - lw.limit = 0 - return -} - -func TestLimitedWriter_Write(t *testing.T) { - assertLimitedWriter_Write(t, 3, []byte{1, 2}, 2, nil, 1) - assertLimitedWriter_Write(t, 3, []byte{1, 2, 3}, 3, nil, 0) - assertLimitedWriter_Write(t, 3, []byte{1, 2, 3, 4}, 3, io.EOF, 0) - assertLimitedWriter_Write(t, 0, []byte{1}, 0, io.EOF, 0) - assertLimitedWriter_Write(t, 0, nil, 0, nil, 0) -} - -func assertLimitedWriter_Write(t *testing.T, limitBefore int, p []byte, n int, err error, limitAfter int) { - t.Helper() - - lw := limitedWriter{limitBefore} - actualN, actualErr := lw.Write(p) - - if !errors.Is(actualErr, err) { - t.Errorf("_, err := (&limitedWriter{%d}).Write(%#v); err != %#v", limitBefore, p, err) - } - - if actualN != n { - t.Errorf("n, _ := (&limitedWriter{%d}).Write(%#v); n != %d", limitBefore, p, n) - } - - if lw.limit != limitAfter { - t.Errorf("lw := limitedWriter{%d}; lw.Write(%#v); lw.limit != %d", limitBefore, p, limitAfter) - } -} - -func TestPackAny(t *testing.T) { - assertPackAny(t, nil, []byte{0}) - assertPackAny(t, false, []byte{1}) - assertPackAny(t, true, []byte{2}) - - assertPackAnyPanic(t, -42, 0) - assertPackAnyPanic(t, int8(-42), 0) - assertPackAnyPanic(t, int16(-42), 0) - assertPackAnyPanic(t, int32(-42), 0) - assertPackAnyPanic(t, int64(-42), 0) - - assertPackAnyPanic(t, uint(42), 0) - assertPackAnyPanic(t, uint8(42), 0) - assertPackAnyPanic(t, uint16(42), 0) - assertPackAnyPanic(t, uint32(42), 0) - assertPackAnyPanic(t, uint64(42), 0) - assertPackAnyPanic(t, uintptr(42), 0) - - assertPackAnyPanic(t, float32(-42.5), 0) - assertPackAny(t, -42.5, []byte{3, 0xc0, 0x45, 0x40, 0, 0, 0, 0, 0}) - - assertPackAnyPanic(t, []struct{}(nil), 9) - assertPackAnyPanic(t, []struct{}{}, 9) - - assertPackAny(t, []interface{}{nil, true, -42.5}, []byte{ - 5, 0, 0, 0, 0, 0, 0, 0, 3, - 0, - 2, - 3, 0xc0, 0x45, 0x40, 0, 0, 0, 0, 0, - }) - - assertPackAny(t, []string{"", "a"}, []byte{ - 5, 0, 0, 0, 0, 0, 0, 0, 2, - 4, 0, 0, 0, 0, 0, 0, 0, 0, - 4, 0, 0, 0, 0, 0, 0, 0, 1, 'a', - }) - - assertPackAnyPanic(t, []interface{}{0 + 0i}, 9) - - assertPackAnyPanic(t, map[struct{}]struct{}(nil), 9) - assertPackAnyPanic(t, map[struct{}]struct{}{}, 9) - - assertPackAny(t, map[interface{}]interface{}{true: "", "nil": -42.5}, []byte{ - 6, 0, 0, 0, 0, 0, 0, 0, 2, - 0, 0, 0, 0, 0, 0, 0, 3, 'n', 'i', 'l', - 3, 0xc0, 0x45, 0x40, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 4, 't', 'r', 'u', 'e', - 4, 0, 0, 0, 0, 0, 0, 0, 0, - }) - - assertPackAny(t, map[string]float64{"": 42}, []byte{ - 6, 0, 0, 0, 0, 0, 0, 0, 1, - 0, 0, 0, 0, 0, 0, 0, 0, - 3, 0x40, 0x45, 0, 0, 0, 0, 0, 0, - }) - - assertPackAny(t, map[[1]byte]bool{{42}: true}, []byte{ - 6, 0, 0, 0, 0, 0, 0, 0, 1, - 0, 0, 0, 0, 0, 0, 0, 1, 42, - 2, - }) - - assertPackAnyPanic(t, map[struct{}]struct{}{{}: {}}, 9) - - assertPackAny(t, (*string)(nil), []byte{0}) - assertPackAnyPanic(t, (*int)(nil), 0) - assertPackAny(t, new(float64), []byte{3, 0, 0, 0, 0, 0, 0, 0, 0}) - - assertPackAny(t, "", []byte{4, 0, 0, 0, 0, 0, 0, 0, 0}) - assertPackAny(t, "a", []byte{4, 0, 0, 0, 0, 0, 0, 0, 1, 'a'}) - assertPackAny(t, "ä", []byte{4, 0, 0, 0, 0, 0, 0, 0, 2, 0xc3, 0xa4}) - - { - var binary [256]byte - for i := range binary { - binary[i] = byte(i) - } - - assertPackAny(t, binary, append([]byte{4, 0, 0, 0, 0, 0, 0, 1, 0}, binary[:]...)) - assertPackAny(t, binary[:], append([]byte{4, 0, 0, 0, 0, 0, 0, 1, 0}, binary[:]...)) - assertPackAny(t, types.Binary(binary[:]), append([]byte{4, 0, 0, 0, 0, 0, 0, 1, 0}, binary[:]...)) - } - - { - type myByte byte - assertPackAnyPanic(t, []myByte(nil), 9) - } - - assertPackAnyPanic(t, complex64(0+0i), 0) - assertPackAnyPanic(t, 0+0i, 0) - assertPackAnyPanic(t, make(chan struct{}), 0) - assertPackAnyPanic(t, func() {}, 0) - assertPackAnyPanic(t, struct{}{}, 0) - assertPackAnyPanic(t, uintptr(0), 0) -} - -func assertPackAny(t *testing.T, in interface{}, out []byte) { - t.Helper() - - { - buf := &bytes.Buffer{} - if err := PackAny(in, buf); err == nil { - if !bytes.Equal(buf.Bytes(), out) { - t.Errorf("buf := &bytes.Buffer{}; packAny(%#v, buf); !bytes.Equal(buf.Bytes(), %#v)", in, out) - } - } else { - t.Errorf("packAny(%#v, &bytes.Buffer{}) != nil", in) - } - } - - for i := 0; i < len(out); i++ { - if !errors.Is(PackAny(in, &limitedWriter{i}), io.EOF) { - t.Errorf("packAny(%#v, &limitedWriter{%d}) != io.EOF", in, i) - } - } -} - -func assertPackAnyPanic(t *testing.T, in interface{}, allowToWrite int) { - t.Helper() - - for i := 0; i < allowToWrite; i++ { - if !errors.Is(PackAny(in, &limitedWriter{i}), io.EOF) { - t.Errorf("packAny(%#v, &limitedWriter{%d}) != io.EOF", in, i) - } - } - - defer func() { - t.Helper() - - if r := recover(); r == nil { - t.Errorf("packAny(%#v, &limitedWriter{%d}) didn't panic", in, allowToWrite) - } - }() - - _ = PackAny(in, &limitedWriter{allowToWrite}) -} diff --git a/pkg/periodic/periodic.go b/pkg/periodic/periodic.go deleted file mode 100644 index 6ef5ceb87..000000000 --- a/pkg/periodic/periodic.go +++ /dev/null @@ -1,123 +0,0 @@ -package periodic - -import ( - "context" - "sync" - "time" -) - -// Option configures Start. -type Option interface { - apply(*periodic) -} - -// Stopper implements the Stop method, -// which stops a periodic task from Start(). -type Stopper interface { - Stop() // Stops a periodic task. -} - -// Tick is the value for periodic task callbacks that -// contains the time of the tick and -// the time elapsed since the start of the periodic task. -type Tick struct { - Elapsed time.Duration - Time time.Time -} - -// Immediate starts the periodic task immediately instead of after the first tick. -func Immediate() Option { - return optionFunc(func(p *periodic) { - p.immediate = true - }) -} - -// OnStop configures a callback that is executed when a periodic task is stopped or canceled. -func OnStop(f func(Tick)) Option { - return optionFunc(func(p *periodic) { - p.onStop = f - }) -} - -// Start starts a periodic task with a ticker at the specified interval, -// which executes the given callback after each tick. -// Pending tasks do not overlap, but could start immediately if -// the previous task(s) takes longer than the interval. -// Call Stop() on the return value in order to stop the ticker and to release associated resources. -// The interval must be greater than zero. -func Start(ctx context.Context, interval time.Duration, callback func(Tick), options ...Option) Stopper { - t := &periodic{ - interval: interval, - callback: callback, - } - - for _, option := range options { - option.apply(t) - } - - ctx, cancelCtx := context.WithCancel(ctx) - - start := time.Now() - - go func() { - done := false - - if !t.immediate { - select { - case <-time.After(interval): - case <-ctx.Done(): - done = true - } - } - - if !done { - ticker := time.NewTicker(t.interval) - defer ticker.Stop() - - for tickTime := time.Now(); !done; { - t.callback(Tick{ - Elapsed: tickTime.Sub(start), - Time: tickTime, - }) - - select { - case tickTime = <-ticker.C: - case <-ctx.Done(): - done = true - } - } - } - - if t.onStop != nil { - now := time.Now() - t.onStop(Tick{ - Elapsed: now.Sub(start), - Time: now, - }) - } - }() - - return stoperFunc(func() { - t.stop.Do(cancelCtx) - }) -} - -type optionFunc func(*periodic) - -func (f optionFunc) apply(p *periodic) { - f(p) -} - -type stoperFunc func() - -func (f stoperFunc) Stop() { - f() -} - -type periodic struct { - interval time.Duration - callback func(Tick) - immediate bool - stop sync.Once - onStop func(Tick) -} diff --git a/pkg/redis/alias.go b/pkg/redis/alias.go deleted file mode 100644 index a8dbdaf46..000000000 --- a/pkg/redis/alias.go +++ /dev/null @@ -1,14 +0,0 @@ -package redis - -import "github.com/redis/go-redis/v9" - -// Alias definitions of commonly used go-redis exports, -// so that only this redis package needs to be imported and not go-redis additionally. - -type IntCmd = redis.IntCmd -type Pipeliner = redis.Pipeliner -type XAddArgs = redis.XAddArgs -type XMessage = redis.XMessage -type XReadArgs = redis.XReadArgs - -var NewScript = redis.NewScript diff --git a/pkg/redis/client.go b/pkg/redis/client.go deleted file mode 100644 index 876d55b15..000000000 --- a/pkg/redis/client.go +++ /dev/null @@ -1,277 +0,0 @@ -package redis - -import ( - "context" - "crypto/tls" - "fmt" - "github.com/icinga/icinga-go-library/backoff" - "github.com/icinga/icinga-go-library/com" - "github.com/icinga/icinga-go-library/logging" - "github.com/icinga/icinga-go-library/periodic" - "github.com/icinga/icinga-go-library/retry" - "github.com/icinga/icinga-go-library/utils" - "github.com/pkg/errors" - "github.com/redis/go-redis/v9" - "go.uber.org/zap" - "golang.org/x/sync/errgroup" - "golang.org/x/sync/semaphore" - "net" - "time" -) - -// Client is a wrapper around redis.Client with -// streaming and logging capabilities. -type Client struct { - *redis.Client - - Options *Options - - logger *logging.Logger -} - -// NewClient returns a new Client wrapper for a pre-existing redis.Client. -func NewClient(client *redis.Client, logger *logging.Logger, options *Options) *Client { - return &Client{Client: client, logger: logger, Options: options} -} - -// NewClientFromConfig returns a new Client from Config. -func NewClientFromConfig(c *Config, logger *logging.Logger) (*Client, error) { - tlsConfig, err := c.TlsOptions.MakeConfig(c.Host) - if err != nil { - return nil, err - } - - var dialer ctxDialerFunc - dl := &net.Dialer{Timeout: 15 * time.Second} - - if tlsConfig == nil { - dialer = dl.DialContext - } else { - dialer = (&tls.Dialer{NetDialer: dl, Config: tlsConfig}).DialContext - } - - options := &redis.Options{ - Dialer: dialWithLogging(dialer, logger), - Password: c.Password, - DB: 0, // Use default DB, - ReadTimeout: c.Options.Timeout, - TLSConfig: tlsConfig, - } - - if utils.IsUnixAddr(c.Host) { - options.Network = "unix" - options.Addr = c.Host - } else { - port := c.Port - if port == 0 { - port = 6379 - } - options.Network = "tcp" - options.Addr = net.JoinHostPort(c.Host, fmt.Sprint(port)) - } - - client := redis.NewClient(options) - options = client.Options() - options.PoolSize = utils.MaxInt(32, options.PoolSize) - options.MaxRetries = options.PoolSize + 1 // https://github.com/go-redis/redis/issues/1737 - - return NewClient(redis.NewClient(options), logger, &c.Options), nil -} - -// GetAddr returns the Redis host:port or Unix socket address. -func (c *Client) GetAddr() string { - return c.Client.Options().Addr -} - -// HPair defines Redis hashes field-value pairs. -type HPair struct { - Field string - Value string -} - -// HYield yields HPair field-value pairs for all fields in the hash stored at key. -func (c *Client) HYield(ctx context.Context, key string) (<-chan HPair, <-chan error) { - pairs := make(chan HPair, c.Options.HScanCount) - - return pairs, com.WaitAsync(com.WaiterFunc(func() error { - var counter com.Counter - defer c.log(ctx, key, &counter).Stop() - defer close(pairs) - - seen := make(map[string]struct{}) - - var cursor uint64 - var err error - var page []string - - for { - cmd := c.HScan(ctx, key, cursor, "", int64(c.Options.HScanCount)) - page, cursor, err = cmd.Result() - - if err != nil { - return WrapCmdErr(cmd) - } - - for i := 0; i < len(page); i += 2 { - if _, ok := seen[page[i]]; ok { - // Ignore duplicate returned by HSCAN. - continue - } - - seen[page[i]] = struct{}{} - - select { - case pairs <- HPair{ - Field: page[i], - Value: page[i+1], - }: - counter.Inc() - case <-ctx.Done(): - return ctx.Err() - } - } - - if cursor == 0 { - break - } - } - - return nil - })) -} - -// HMYield yields HPair field-value pairs for the specified fields in the hash stored at key. -func (c *Client) HMYield(ctx context.Context, key string, fields ...string) (<-chan HPair, <-chan error) { - pairs := make(chan HPair) - - return pairs, com.WaitAsync(com.WaiterFunc(func() error { - var counter com.Counter - defer c.log(ctx, key, &counter).Stop() - - g, ctx := errgroup.WithContext(ctx) - - defer func() { - // Wait until the group is done so that we can safely close the pairs channel, - // because on error, sem.Acquire will return before calling g.Wait(), - // which can result in goroutines working on a closed channel. - _ = g.Wait() - close(pairs) - }() - - // Use context from group. - batches := utils.BatchSliceOfStrings(ctx, fields, c.Options.HMGetCount) - - sem := semaphore.NewWeighted(int64(c.Options.MaxHMGetConnections)) - - for batch := range batches { - if err := sem.Acquire(ctx, 1); err != nil { - return errors.Wrap(err, "can't acquire semaphore") - } - - batch := batch - g.Go(func() error { - defer sem.Release(1) - - cmd := c.HMGet(ctx, key, batch...) - vals, err := cmd.Result() - - if err != nil { - return WrapCmdErr(cmd) - } - - for i, v := range vals { - if v == nil { - c.logger.Warnf("HMGET %s: field %#v missing", key, batch[i]) - continue - } - - select { - case pairs <- HPair{ - Field: batch[i], - Value: v.(string), - }: - counter.Inc() - case <-ctx.Done(): - return ctx.Err() - } - } - - return nil - }) - } - - return g.Wait() - })) -} - -// XReadUntilResult (repeatedly) calls XREAD with the specified arguments until a result is returned. -// Each call blocks at most for the duration specified in Options.BlockTimeout until data -// is available before it times out and the next call is made. -// This also means that an already set block timeout is overridden. -func (c *Client) XReadUntilResult(ctx context.Context, a *redis.XReadArgs) ([]redis.XStream, error) { - a.Block = c.Options.BlockTimeout - - for { - cmd := c.XRead(ctx, a) - streams, err := cmd.Result() - if err != nil { - if errors.Is(err, redis.Nil) { - continue - } - - return streams, WrapCmdErr(cmd) - } - - return streams, nil - } -} - -func (c *Client) log(ctx context.Context, key string, counter *com.Counter) periodic.Stopper { - return periodic.Start(ctx, c.logger.Interval(), func(tick periodic.Tick) { - // We may never get to progress logging here, - // as fetching should be completed before the interval expires, - // but if it does, it is good to have this log message. - if count := counter.Reset(); count > 0 { - c.logger.Debugf("Fetched %d items from %s", count, key) - } - }, periodic.OnStop(func(tick periodic.Tick) { - c.logger.Debugf("Finished fetching from %s with %d items in %s", key, counter.Total(), tick.Elapsed) - })) -} - -type ctxDialerFunc = func(ctx context.Context, network, addr string) (net.Conn, error) - -// dialWithLogging returns a Redis Dialer with logging capabilities. -func dialWithLogging(dialer ctxDialerFunc, logger *logging.Logger) ctxDialerFunc { - // dial behaves like net.Dialer#DialContext, - // but re-tries on common errors that are considered retryable. - return func(ctx context.Context, network, addr string) (conn net.Conn, err error) { - err = retry.WithBackoff( - ctx, - func(ctx context.Context) (err error) { - conn, err = dialer(ctx, network, addr) - return - }, - retry.Retryable, - backoff.NewExponentialWithJitter(1*time.Millisecond, 1*time.Second), - retry.Settings{ - Timeout: retry.DefaultTimeout, - OnRetryableError: func(_ time.Duration, _ uint64, err, lastErr error) { - if lastErr == nil || err.Error() != lastErr.Error() { - logger.Warnw("Can't connect to Redis. Retrying", zap.Error(err)) - } - }, - OnSuccess: func(elapsed time.Duration, attempt uint64, _ error) { - if attempt > 1 { - logger.Infow("Reconnected to Redis", - zap.Duration("after", elapsed), zap.Uint64("attempts", attempt)) - } - }, - }, - ) - - err = errors.Wrap(err, "can't connect to Redis") - - return - } -} diff --git a/pkg/redis/config.go b/pkg/redis/config.go deleted file mode 100644 index d59d37fb2..000000000 --- a/pkg/redis/config.go +++ /dev/null @@ -1,59 +0,0 @@ -package redis - -import ( - "github.com/icinga/icinga-go-library/config" - "github.com/pkg/errors" - "time" -) - -// Options define user configurable Redis options. -type Options struct { - BlockTimeout time.Duration `yaml:"block_timeout" default:"1s"` - HMGetCount int `yaml:"hmget_count" default:"4096"` - HScanCount int `yaml:"hscan_count" default:"4096"` - MaxHMGetConnections int `yaml:"max_hmget_connections" default:"8"` - Timeout time.Duration `yaml:"timeout" default:"30s"` - XReadCount int `yaml:"xread_count" default:"4096"` -} - -// Validate checks constraints in the supplied Redis options and returns an error if they are violated. -func (o *Options) Validate() error { - if o.BlockTimeout <= 0 { - return errors.New("block_timeout must be positive") - } - if o.HMGetCount < 1 { - return errors.New("hmget_count must be at least 1") - } - if o.HScanCount < 1 { - return errors.New("hscan_count must be at least 1") - } - if o.MaxHMGetConnections < 1 { - return errors.New("max_hmget_connections must be at least 1") - } - if o.Timeout == 0 { - return errors.New("timeout cannot be 0. Configure a value greater than zero, or use -1 for no timeout") - } - if o.XReadCount < 1 { - return errors.New("xread_count must be at least 1") - } - - return nil -} - -// Config defines Config client configuration. -type Config struct { - Host string `yaml:"host"` - Port int `yaml:"port"` - Password string `yaml:"password"` - TlsOptions config.TLS `yaml:",inline"` - Options Options `yaml:"options"` -} - -// Validate checks constraints in the supplied Config configuration and returns an error if they are violated. -func (r *Config) Validate() error { - if r.Host == "" { - return errors.New("Redis host missing") - } - - return r.Options.Validate() -} diff --git a/pkg/redis/streams.go b/pkg/redis/streams.go deleted file mode 100644 index 737954369..000000000 --- a/pkg/redis/streams.go +++ /dev/null @@ -1,20 +0,0 @@ -package redis - -// Streams represents a Redis stream key to ID mapping. -type Streams map[string]string - -// Option returns the Redis stream key to ID mapping -// as a slice of stream keys followed by their IDs -// that is compatible for the Redis STREAMS option. -func (s Streams) Option() []string { - // len*2 because we're appending the IDs later. - streams := make([]string, 0, len(s)*2) - ids := make([]string, 0, len(s)) - - for key, id := range s { - streams = append(streams, key) - ids = append(ids, id) - } - - return append(streams, ids...) -} diff --git a/pkg/redis/utils.go b/pkg/redis/utils.go deleted file mode 100644 index 66a38e30b..000000000 --- a/pkg/redis/utils.go +++ /dev/null @@ -1,22 +0,0 @@ -package redis - -import ( - "context" - "github.com/icinga/icinga-go-library/utils" - "github.com/pkg/errors" - "github.com/redis/go-redis/v9" -) - -// WrapCmdErr adds the command itself and -// the stack of the current goroutine to the command's error if any. -func WrapCmdErr(cmd redis.Cmder) error { - err := cmd.Err() - if err != nil { - err = errors.Wrapf(err, "can't perform %q", utils.Ellipsize( - redis.NewCmd(context.Background(), cmd.Args()).String(), // Omits error in opposite to cmd.String() - 100, - )) - } - - return err -} diff --git a/pkg/retry/retry.go b/pkg/retry/retry.go deleted file mode 100644 index 0686b0fbf..000000000 --- a/pkg/retry/retry.go +++ /dev/null @@ -1,201 +0,0 @@ -package retry - -import ( - "context" - "database/sql/driver" - "github.com/go-sql-driver/mysql" - "github.com/icinga/icinga-go-library/backoff" - "github.com/lib/pq" - "github.com/pkg/errors" - "io" - "net" - "syscall" - "time" -) - -// DefaultTimeout is our opinionated default timeout for retrying database and Redis operations. -const DefaultTimeout = 5 * time.Minute - -// RetryableFunc is a retryable function. -type RetryableFunc func(context.Context) error - -// IsRetryable checks whether a new attempt can be started based on the error passed. -type IsRetryable func(error) bool - -// OnRetryableErrorFunc is called if a retryable error occurs. -type OnRetryableErrorFunc func(elapsed time.Duration, attempt uint64, err, lastErr error) - -// OnSuccessFunc is called once the operation succeeds. -type OnSuccessFunc func(elapsed time.Duration, attempt uint64, lastErr error) - -// Settings aggregates optional settings for WithBackoff. -type Settings struct { - // If >0, Timeout lets WithBackoff stop retrying gracefully once elapsed based on the following criteria: - // * If the execution of RetryableFunc has taken longer than Timeout, no further attempts are made. - // * If Timeout elapses during the sleep phase between retries, one final retry is attempted. - // * RetryableFunc is always granted its full execution time and is not canceled if it exceeds Timeout. - // This means that WithBackoff may not stop exactly after Timeout expires, - // or may not retry at all if the first execution of RetryableFunc already takes longer than Timeout. - Timeout time.Duration - OnRetryableError OnRetryableErrorFunc - OnSuccess OnSuccessFunc -} - -// WithBackoff retries the passed function if it fails and the error allows it to retry. -// The specified backoff policy is used to determine how long to sleep between attempts. -func WithBackoff( - ctx context.Context, retryableFunc RetryableFunc, retryable IsRetryable, b backoff.Backoff, settings Settings, -) (err error) { - // Channel for retry deadline, which is set to the channel of NewTimer() if a timeout is configured, - // otherwise nil, so that it blocks forever if there is no timeout. - var timeout <-chan time.Time - - if settings.Timeout > 0 { - t := time.NewTimer(settings.Timeout) - defer t.Stop() - timeout = t.C - } - - start := time.Now() - timedOut := false - for attempt := uint64(1); ; /* true */ attempt++ { - prevErr := err - - if err = retryableFunc(ctx); err == nil { - if settings.OnSuccess != nil { - settings.OnSuccess(time.Since(start), attempt, prevErr) - } - - return - } - - // Retryable function may have exited prematurely due to context errors. - // We explicitly check the context error here, as the error returned by the retryable function can pass the - // error.Is() checks even though it is not a real context error, e.g. - // https://cs.opensource.google/go/go/+/refs/tags/go1.22.2:src/net/net.go;l=422 - // https://cs.opensource.google/go/go/+/refs/tags/go1.22.2:src/net/net.go;l=601 - if errors.Is(ctx.Err(), context.DeadlineExceeded) || errors.Is(ctx.Err(), context.Canceled) { - if prevErr != nil { - err = errors.Wrap(err, prevErr.Error()) - } - - return - } - - if !retryable(err) { - err = errors.Wrap(err, "can't retry") - - return - } - - select { - case <-timeout: - // Stop retrying immediately if executing the retryable function took longer than the timeout. - timedOut = true - default: - } - - if timedOut { - err = errors.Wrap(err, "retry deadline exceeded") - - return - } - - if settings.OnRetryableError != nil { - settings.OnRetryableError(time.Since(start), attempt, err, prevErr) - } - - select { - case <-time.After(b(attempt)): - case <-timeout: - // Do not stop retrying immediately, but start one last attempt to mitigate timing issues where - // the timeout expires while waiting for the next attempt and - // therefore no retries have happened during this possibly long period. - timedOut = true - case <-ctx.Done(): - err = errors.Wrap(ctx.Err(), err.Error()) - - return - } - } -} - -// ResetTimeout changes the possibly expired timer t to expire after duration d. -// -// If the timer has already expired and nothing has been received from its channel, -// it is automatically drained as if the timer had never expired. -func ResetTimeout(t *time.Timer, d time.Duration) { - if !t.Stop() { - <-t.C - } - - t.Reset(d) -} - -// Retryable returns true for common errors that are considered retryable, -// i.e. temporary, timeout, DNS, connection refused and reset, host down and unreachable and -// network down and unreachable errors. In addition, any database error is considered retryable. -func Retryable(err error) bool { - var temporary interface { - Temporary() bool - } - if errors.As(err, &temporary) && temporary.Temporary() { - return true - } - - var timeout interface { - Timeout() bool - } - if errors.As(err, &timeout) && timeout.Timeout() { - return true - } - - var dnsError *net.DNSError - if errors.As(err, &dnsError) { - return true - } - - var opError *net.OpError - if errors.As(err, &opError) { - // OpError provides Temporary() and Timeout(), but not Unwrap(), - // so we have to extract the underlying error ourselves to also check for ECONNREFUSED, - // which is not considered temporary or timed out by Go. - err = opError.Err - } - if errors.Is(err, syscall.ECONNREFUSED) || errors.Is(err, syscall.ENOENT) { - // syscall errors provide Temporary() and Timeout(), - // which do not include ECONNREFUSED or ENOENT, so we check these ourselves. - return true - } - if errors.Is(err, syscall.ECONNRESET) { - // ECONNRESET is treated as a temporary error by Go only if it comes from calling accept. - return true - } - if errors.Is(err, syscall.EHOSTDOWN) || errors.Is(err, syscall.EHOSTUNREACH) { - return true - } - if errors.Is(err, syscall.ENETDOWN) || errors.Is(err, syscall.ENETUNREACH) { - return true - } - if errors.Is(err, syscall.EPIPE) { - return true - } - if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { - return true - } - - if errors.Is(err, driver.ErrBadConn) { - return true - } - if errors.Is(err, mysql.ErrInvalidConn) { - return true - } - - var mye *mysql.MySQLError - var pqe *pq.Error - if errors.As(err, &mye) || errors.As(err, &pqe) { - return true - } - - return false -} diff --git a/pkg/strcase/strcase.go b/pkg/strcase/strcase.go deleted file mode 100644 index 45726cfd3..000000000 --- a/pkg/strcase/strcase.go +++ /dev/null @@ -1,56 +0,0 @@ -// Package strcase implements functions to convert a camelCase UTF-8 string into various cases. -// -// New delimiters will be inserted based on the following transitions: -// - On any change from lowercase to uppercase letter. -// - On any change from number to uppercase letter. -package strcase - -import ( - "strings" - "unicode" -) - -// Delimited converts a string to delimited.lower.case, here using `.` as delimiter. -func Delimited(s string, d rune) string { - return convert(s, unicode.LowerCase, d) -} - -// ScreamingDelimited converts a string to DELIMITED.UPPER.CASE, here using `.` as delimiter. -func ScreamingDelimited(s string, d rune) string { - return convert(s, unicode.UpperCase, d) -} - -// Snake converts a string to snake_case. -func Snake(s string) string { - return Delimited(s, '_') -} - -// ScreamingSnake converts a string to SCREAMING_SNAKE_CASE. -func ScreamingSnake(s string) string { - return ScreamingDelimited(s, '_') -} - -// convert converts a camelCase UTF-8 string into various cases. -// _case must be unicode.LowerCase or unicode.UpperCase. -func convert(s string, _case int, d rune) string { - if len(s) == 0 { - return s - } - - n := strings.Builder{} - n.Grow(len(s) + 2) // Allow adding at least 2 delimiters without another allocation. - - var prevRune rune - - for i, r := range s { - if i > 0 && unicode.IsUpper(r) && (unicode.IsNumber(prevRune) || unicode.IsLower(prevRune)) { - n.WriteRune(d) - } - - n.WriteRune(unicode.To(_case, r)) - - prevRune = r - } - - return n.String() -} diff --git a/pkg/strcase/strcase_test.go b/pkg/strcase/strcase_test.go deleted file mode 100644 index 382db8d74..000000000 --- a/pkg/strcase/strcase_test.go +++ /dev/null @@ -1,58 +0,0 @@ -package strcase - -import ( - "strings" - "testing" -) - -var tests = [][]string{ - {"", ""}, - {"Test", "test"}, - {"test", "test"}, - {"testCase", "test_case"}, - {"test_case", "test_case"}, - {"TestCase", "test_case"}, - {"Test_Case", "test_case"}, - {"ID", "id"}, - {"userID", "user_id"}, - {"UserID", "user_id"}, - {"ManyManyWords", "many_many_words"}, - {"manyManyWords", "many_many_words"}, - {"icinga2", "icinga2"}, - {"Icinga2Version", "icinga2_version"}, - {"k8sVersion", "k8s_version"}, - {"1234", "1234"}, - {"a1b2c3d4", "a1b2c3d4"}, - {"with1234digits", "with1234digits"}, - {"with1234Digits", "with1234_digits"}, - {"IPv4", "ipv4"}, - {"IPv4Address", "ipv4_address"}, - {"caféCrème", "café_crème"}, - {"0℃", "0℃"}, - {"~0", "~0"}, - {"icinga💯points", "icinga💯points"}, - {"😃🙃😀", "😃🙃😀"}, - {"こんにちは", "こんにちは"}, - {"\xff\xfe\xfd", "���"}, - {"\xff", "�"}, -} - -func TestSnake(t *testing.T) { - for _, test := range tests { - s, expected := test[0], test[1] - actual := Snake(s) - if actual != expected { - t.Errorf("%q: %q != %q", s, actual, expected) - } - } -} - -func TestScreamingSnake(t *testing.T) { - for _, test := range tests { - s, expected := test[0], strings.ToUpper(test[1]) - actual := ScreamingSnake(s) - if actual != expected { - t.Errorf("%q: %q != %q", s, actual, expected) - } - } -} diff --git a/pkg/structify/structify.go b/pkg/structify/structify.go deleted file mode 100644 index 94d75267e..000000000 --- a/pkg/structify/structify.go +++ /dev/null @@ -1,176 +0,0 @@ -package structify - -import ( - "encoding" - "fmt" - "github.com/pkg/errors" - "golang.org/x/exp/constraints" - "reflect" - "strconv" - "strings" - "unsafe" -) - -// structBranch represents either a leaf or a subTree. -type structBranch struct { - // field specifies the struct field index. - field int - // leaf specifies the map key to parse the struct field from. - leaf string - // subTree specifies the struct field's inner tree. - subTree []structBranch -} - -type MapStructifier = func(map[string]interface{}) (interface{}, error) - -// MakeMapStructifier builds a function which parses a map's string values into a new struct of type t -// and returns a pointer to it. tag specifies which tag connects struct fields to map keys. -// MakeMapStructifier panics if it detects an unsupported type (suitable for usage in init() or global vars). -func MakeMapStructifier(t reflect.Type, tag string, initer func(any)) MapStructifier { - tree := buildStructTree(t, tag) - - return func(kv map[string]interface{}) (interface{}, error) { - vPtr := reflect.New(t) - ptr := vPtr.Interface() - if initer != nil { - initer(ptr) - } - vPtrElem := vPtr.Elem() - err := errors.Wrapf(structifyMapByTree(kv, tree, vPtrElem, vPtrElem, new([]int)), "can't structify map %#v by tree %#v", kv, tree) - - return ptr, err - } -} - -// buildStructTree assembles a tree which represents the struct t based on tag. -func buildStructTree(t reflect.Type, tag string) []structBranch { - var tree []structBranch - numFields := t.NumField() - - for i := 0; i < numFields; i++ { - if field := t.Field(i); field.PkgPath == "" { - switch tagValue := field.Tag.Get(tag); tagValue { - case "", "-": - case ",inline": - if subTree := buildStructTree(field.Type, tag); subTree != nil { - tree = append(tree, structBranch{i, "", subTree}) - } - default: - // If parseString doesn't support *T, it'll panic. - _ = parseString("", reflect.New(field.Type).Interface()) - - tree = append(tree, structBranch{i, tagValue, nil}) - } - } - } - - return tree -} - -// structifyMapByTree parses src's string values into the struct dest according to tree's specification. -func structifyMapByTree(src map[string]interface{}, tree []structBranch, dest, root reflect.Value, stack *[]int) error { - *stack = append(*stack, 0) - defer func() { - *stack = (*stack)[:len(*stack)-1] - }() - - for _, branch := range tree { - (*stack)[len(*stack)-1] = branch.field - - if branch.subTree == nil { - if v, ok := src[branch.leaf]; ok { - if vs, ok := v.(string); ok { - if err := parseString(vs, dest.Field(branch.field).Addr().Interface()); err != nil { - rt := root.Type() - typ := rt - var path []string - - for _, i := range *stack { - f := typ.Field(i) - path = append(path, f.Name) - typ = f.Type - } - - return errors.Wrapf(err, "can't parse %s into the %s %s#%s: %s", - branch.leaf, typ.Name(), rt.Name(), strings.Join(path, "."), vs) - } - } - } - } else if err := structifyMapByTree(src, branch.subTree, dest.Field(branch.field), root, stack); err != nil { - return err - } - } - - return nil -} - -// parseString parses src into *dest. -func parseString(src string, dest interface{}) error { - switch ptr := dest.(type) { - case encoding.TextUnmarshaler: - return ptr.UnmarshalText([]byte(src)) - case *string: - *ptr = src - return nil - case **string: - *ptr = &src - return nil - case *uint8: - return parseUint(src, ptr) - case *uint16: - return parseUint(src, ptr) - case *uint32: - return parseUint(src, ptr) - case *uint64: - return parseUint(src, ptr) - case *int8: - return parseInt(src, ptr) - case *int16: - return parseInt(src, ptr) - case *int32: - return parseInt(src, ptr) - case *int64: - return parseInt(src, ptr) - case *float32: - return parseFloat(src, ptr) - case *float64: - return parseFloat(src, ptr) - default: - panic(fmt.Sprintf("unsupported type: %T", dest)) - } -} - -// parseUint parses src into *dest. -func parseUint[T constraints.Unsigned](src string, dest *T) error { - i, err := strconv.ParseUint(src, 10, bitSizeOf[T]()) - if err == nil { - *dest = T(i) - } - - return err -} - -// parseInt parses src into *dest. -func parseInt[T constraints.Signed](src string, dest *T) error { - i, err := strconv.ParseInt(src, 10, bitSizeOf[T]()) - if err == nil { - *dest = T(i) - } - - return err -} - -// parseFloat parses src into *dest. -func parseFloat[T constraints.Float](src string, dest *T) error { - f, err := strconv.ParseFloat(src, bitSizeOf[T]()) - if err == nil { - *dest = T(f) - } - - return err -} - -func bitSizeOf[T any]() int { - var x T - return int(unsafe.Sizeof(x) * 8) -} diff --git a/pkg/types/binary.go b/pkg/types/binary.go deleted file mode 100644 index 2a786afef..000000000 --- a/pkg/types/binary.go +++ /dev/null @@ -1,125 +0,0 @@ -package types - -import ( - "bytes" - "database/sql" - "database/sql/driver" - "encoding" - "encoding/hex" - "encoding/json" - "fmt" - "github.com/pkg/errors" -) - -// Binary nullable byte string. Hex as JSON. -type Binary []byte - -// nullBinary for validating whether a Binary is valid. -var nullBinary Binary - -// Valid returns whether the Binary is valid. -func (binary Binary) Valid() bool { - return !bytes.Equal(binary, nullBinary) -} - -// String returns the hex string representation form of the Binary. -func (binary Binary) String() string { - return hex.EncodeToString(binary) -} - -// MarshalText implements a custom marhsal function to encode -// the Binary as hex. MarshalText implements the -// encoding.TextMarshaler interface. -func (binary Binary) MarshalText() ([]byte, error) { - return []byte(binary.String()), nil -} - -// UnmarshalText implements a custom unmarshal function to decode -// hex into a Binary. UnmarshalText implements the -// encoding.TextUnmarshaler interface. -func (binary *Binary) UnmarshalText(text []byte) error { - b := make([]byte, hex.DecodedLen(len(text))) - _, err := hex.Decode(b, text) - if err != nil { - return CantDecodeHex(err, string(text)) - } - *binary = b - - return nil -} - -// MarshalJSON implements a custom marshal function to encode the Binary -// as a hex string. MarshalJSON implements the json.Marshaler interface. -// Supports JSON null. -func (binary Binary) MarshalJSON() ([]byte, error) { - if !binary.Valid() { - return []byte("null"), nil - } - - return MarshalJSON(binary.String()) -} - -// UnmarshalJSON implements a custom unmarshal function to decode -// a JSON hex string into a Binary. UnmarshalJSON implements the -// json.Unmarshaler interface. Supports JSON null. -func (binary *Binary) UnmarshalJSON(data []byte) error { - if string(data) == "null" || len(data) == 0 { - return nil - } - - var s string - if err := UnmarshalJSON(data, &s); err != nil { - return err - } - b, err := hex.DecodeString(s) - if err != nil { - return CantDecodeHex(err, s) - } - *binary = b - - return nil -} - -// Scan implements the sql.Scanner interface. -// Supports SQL NULL. -func (binary *Binary) Scan(src interface{}) error { - switch src := src.(type) { - case nil: - return nil - - case []byte: - if len(src) == 0 { - return nil - } - - b := make([]byte, len(src)) - copy(b, src) - *binary = b - - default: - return errors.Errorf("unable to scan type %T into Binary", src) - } - - return nil -} - -// Value implements the driver.Valuer interface. -// Supports SQL NULL. -func (binary Binary) Value() (driver.Value, error) { - if !binary.Valid() { - return nil, nil - } - - return []byte(binary), nil -} - -// Assert interface compliance. -var ( - _ fmt.Stringer = (*Binary)(nil) - _ encoding.TextMarshaler = (*Binary)(nil) - _ encoding.TextUnmarshaler = (*Binary)(nil) - _ json.Marshaler = (*Binary)(nil) - _ json.Unmarshaler = (*Binary)(nil) - _ sql.Scanner = (*Binary)(nil) - _ driver.Valuer = (*Binary)(nil) -) diff --git a/pkg/types/binary_test.go b/pkg/types/binary_test.go deleted file mode 100644 index 2a4f82920..000000000 --- a/pkg/types/binary_test.go +++ /dev/null @@ -1,29 +0,0 @@ -package types - -import ( - "github.com/stretchr/testify/require" - "testing" - "unicode/utf8" -) - -func TestBinary_MarshalJSON(t *testing.T) { - subtests := []struct { - name string - input Binary - output string - }{ - {"nil", nil, `null`}, - {"empty", make(Binary, 0, 1), `null`}, - {"space", Binary(" "), `"20"`}, - } - - for _, st := range subtests { - t.Run(st.name, func(t *testing.T) { - actual, err := st.input.MarshalJSON() - - require.NoError(t, err) - require.True(t, utf8.Valid(actual)) - require.Equal(t, st.output, string(actual)) - }) - } -} diff --git a/pkg/types/bool.go b/pkg/types/bool.go deleted file mode 100644 index cd0af2cea..000000000 --- a/pkg/types/bool.go +++ /dev/null @@ -1,104 +0,0 @@ -package types - -import ( - "database/sql" - "database/sql/driver" - "encoding" - "encoding/json" - "github.com/pkg/errors" - "strconv" -) - -var ( - enum = map[bool]string{ - true: "y", - false: "n", - } -) - -// Bool represents a bool for ENUM ('y', 'n'), which can be NULL. -type Bool struct { - Bool bool - Valid bool // Valid is true if Bool is not NULL -} - -// MarshalJSON implements the json.Marshaler interface. -func (b Bool) MarshalJSON() ([]byte, error) { - if !b.Valid { - return []byte("null"), nil - } - - return MarshalJSON(b.Bool) -} - -// UnmarshalText implements the encoding.TextUnmarshaler interface. -func (b *Bool) UnmarshalText(text []byte) error { - parsed, err := strconv.ParseUint(string(text), 10, 64) - if err != nil { - return CantParseUint64(err, string(text)) - } - - *b = Bool{parsed != 0, true} - return nil -} - -// UnmarshalJSON implements the json.Unmarshaler interface. -func (b *Bool) UnmarshalJSON(data []byte) error { - if string(data) == "null" || len(data) == 0 { - return nil - } - - if err := UnmarshalJSON(data, &b.Bool); err != nil { - return err - } - - b.Valid = true - - return nil -} - -// Scan implements the sql.Scanner interface. -// Supports SQL NULL. -func (b *Bool) Scan(src interface{}) error { - if src == nil { - b.Bool, b.Valid = false, false - return nil - } - - v, ok := src.([]byte) - if !ok { - return errors.Errorf("bad []byte type assertion from %#v", src) - } - - switch string(v) { - case "y": - b.Bool = true - case "n": - b.Bool = false - default: - return errors.Errorf("bad bool %#v", v) - } - - b.Valid = true - - return nil -} - -// Value implements the driver.Valuer interface. -// Supports SQL NULL. -func (b Bool) Value() (driver.Value, error) { - if !b.Valid { - return nil, nil - } - - return enum[b.Bool], nil -} - -// Assert interface compliance. -var ( - _ json.Marshaler = (*Bool)(nil) - _ encoding.TextUnmarshaler = (*Bool)(nil) - _ json.Unmarshaler = (*Bool)(nil) - _ sql.Scanner = (*Bool)(nil) - _ driver.Valuer = (*Bool)(nil) -) diff --git a/pkg/types/bool_test.go b/pkg/types/bool_test.go deleted file mode 100644 index fe49588c8..000000000 --- a/pkg/types/bool_test.go +++ /dev/null @@ -1,30 +0,0 @@ -package types - -import ( - "fmt" - "github.com/stretchr/testify/require" - "testing" - "unicode/utf8" -) - -func TestBool_MarshalJSON(t *testing.T) { - subtests := []struct { - input Bool - output string - }{ - {Bool{Bool: false, Valid: false}, `null`}, - {Bool{Bool: false, Valid: true}, `false`}, - {Bool{Bool: true, Valid: false}, `null`}, - {Bool{Bool: true, Valid: true}, `true`}, - } - - for _, st := range subtests { - t.Run(fmt.Sprintf("Bool-%#v_Valid-%#v", st.input.Bool, st.input.Valid), func(t *testing.T) { - actual, err := st.input.MarshalJSON() - - require.NoError(t, err) - require.True(t, utf8.Valid(actual)) - require.Equal(t, st.output, string(actual)) - }) - } -} diff --git a/pkg/types/float.go b/pkg/types/float.go deleted file mode 100644 index dcd51c6f0..000000000 --- a/pkg/types/float.go +++ /dev/null @@ -1,67 +0,0 @@ -package types - -import ( - "bytes" - "database/sql" - "database/sql/driver" - "encoding" - "encoding/json" - "strconv" -) - -// Float adds JSON support to sql.NullFloat64. -type Float struct { - sql.NullFloat64 -} - -// MarshalJSON implements the json.Marshaler interface. -// Supports JSON null. -func (f Float) MarshalJSON() ([]byte, error) { - var v interface{} - if f.Valid { - v = f.Float64 - } - - return MarshalJSON(v) -} - -// UnmarshalText implements the encoding.TextUnmarshaler interface. -func (f *Float) UnmarshalText(text []byte) error { - parsed, err := strconv.ParseFloat(string(text), 64) - if err != nil { - return CantParseFloat64(err, string(text)) - } - - *f = Float{sql.NullFloat64{ - Float64: parsed, - Valid: true, - }} - - return nil -} - -// UnmarshalJSON implements the json.Unmarshaler interface. -// Supports JSON null. -func (f *Float) UnmarshalJSON(data []byte) error { - // Ignore null, like in the main JSON package. - if bytes.HasPrefix(data, []byte{'n'}) { - return nil - } - - if err := UnmarshalJSON(data, &f.Float64); err != nil { - return err - } - - f.Valid = true - - return nil -} - -// Assert interface compliance. -var ( - _ json.Marshaler = Float{} - _ encoding.TextUnmarshaler = (*Float)(nil) - _ json.Unmarshaler = (*Float)(nil) - _ driver.Valuer = Float{} - _ sql.Scanner = (*Float)(nil) -) diff --git a/pkg/types/int.go b/pkg/types/int.go deleted file mode 100644 index 448180f16..000000000 --- a/pkg/types/int.go +++ /dev/null @@ -1,67 +0,0 @@ -package types - -import ( - "bytes" - "database/sql" - "database/sql/driver" - "encoding" - "encoding/json" - "strconv" -) - -// Int adds JSON support to sql.NullInt64. -type Int struct { - sql.NullInt64 -} - -// MarshalJSON implements the json.Marshaler interface. -// Supports JSON null. -func (i Int) MarshalJSON() ([]byte, error) { - var v interface{} - if i.Valid { - v = i.Int64 - } - - return MarshalJSON(v) -} - -// UnmarshalText implements the encoding.TextUnmarshaler interface. -func (i *Int) UnmarshalText(text []byte) error { - parsed, err := strconv.ParseInt(string(text), 10, 64) - if err != nil { - return CantParseInt64(err, string(text)) - } - - *i = Int{sql.NullInt64{ - Int64: parsed, - Valid: true, - }} - - return nil -} - -// UnmarshalJSON implements the json.Unmarshaler interface. -// Supports JSON null. -func (i *Int) UnmarshalJSON(data []byte) error { - // Ignore null, like in the main JSON package. - if bytes.HasPrefix(data, []byte{'n'}) { - return nil - } - - if err := UnmarshalJSON(data, &i.Int64); err != nil { - return err - } - - i.Valid = true - - return nil -} - -// Assert interface compliance. -var ( - _ json.Marshaler = Int{} - _ json.Unmarshaler = (*Int)(nil) - _ encoding.TextUnmarshaler = (*Int)(nil) - _ driver.Valuer = Int{} - _ sql.Scanner = (*Int)(nil) -) diff --git a/pkg/types/string.go b/pkg/types/string.go deleted file mode 100644 index c01bf964f..000000000 --- a/pkg/types/string.go +++ /dev/null @@ -1,81 +0,0 @@ -package types - -import ( - "bytes" - "database/sql" - "database/sql/driver" - "encoding" - "encoding/json" - "strings" -) - -// String adds JSON support to sql.NullString. -type String struct { - sql.NullString -} - -// MakeString constructs a new non-NULL String from s. -func MakeString(s string) String { - return String{sql.NullString{ - String: s, - Valid: true, - }} -} - -// MarshalJSON implements the json.Marshaler interface. -// Supports JSON null. -func (s String) MarshalJSON() ([]byte, error) { - var v interface{} - if s.Valid { - v = s.String - } - - return MarshalJSON(v) -} - -// UnmarshalText implements the encoding.TextUnmarshaler interface. -func (s *String) UnmarshalText(text []byte) error { - *s = String{sql.NullString{ - String: string(text), - Valid: true, - }} - - return nil -} - -// UnmarshalJSON implements the json.Unmarshaler interface. -// Supports JSON null. -func (s *String) UnmarshalJSON(data []byte) error { - // Ignore null, like in the main JSON package. - if bytes.HasPrefix(data, []byte{'n'}) { - return nil - } - - if err := UnmarshalJSON(data, &s.String); err != nil { - return err - } - - s.Valid = true - - return nil -} - -// Value implements the driver.Valuer interface. -// Supports SQL NULL. -func (s String) Value() (driver.Value, error) { - if !s.Valid { - return nil, nil - } - - // PostgreSQL does not allow null bytes in varchar, char and text fields. - return strings.ReplaceAll(s.String, "\x00", ""), nil -} - -// Assert interface compliance. -var ( - _ json.Marshaler = String{} - _ encoding.TextUnmarshaler = (*String)(nil) - _ json.Unmarshaler = (*String)(nil) - _ driver.Valuer = String{} - _ sql.Scanner = (*String)(nil) -) diff --git a/pkg/types/unix_milli.go b/pkg/types/unix_milli.go deleted file mode 100644 index 943f07992..000000000 --- a/pkg/types/unix_milli.go +++ /dev/null @@ -1,116 +0,0 @@ -package types - -import ( - "bytes" - "database/sql" - "database/sql/driver" - "encoding" - "encoding/json" - "github.com/pkg/errors" - "math" - "strconv" - "time" -) - -// UnixMilli is a nullable millisecond UNIX timestamp in databases and JSON. -type UnixMilli time.Time - -// Time returns the time.Time conversion of UnixMilli. -func (t UnixMilli) Time() time.Time { - return time.Time(t) -} - -// MarshalJSON implements the json.Marshaler interface. -// Marshals to milliseconds. Supports JSON null. -func (t UnixMilli) MarshalJSON() ([]byte, error) { - if time.Time(t).IsZero() { - return []byte("null"), nil - } - - return []byte(strconv.FormatInt(t.Time().UnixMilli(), 10)), nil -} - -// UnmarshalJSON implements the json.Unmarshaler interface. -// Unmarshals from milliseconds. Supports JSON null. -func (t *UnixMilli) UnmarshalJSON(data []byte) error { - if bytes.Equal(data, []byte("null")) || len(data) == 0 { - return nil - } - - return t.fromByteString(data) -} - -// MarshalText implements the encoding.TextMarshaler interface. -func (t UnixMilli) MarshalText() ([]byte, error) { - if time.Time(t).IsZero() { - return []byte{}, nil - } - - return []byte(strconv.FormatInt(t.Time().UnixMilli(), 10)), nil -} - -// UnmarshalText implements the encoding.TextUnmarshaler interface. -func (t *UnixMilli) UnmarshalText(text []byte) error { - if len(text) == 0 { - return nil - } - - return t.fromByteString(text) -} - -// Scan implements the sql.Scanner interface. -// Scans from milliseconds. Supports SQL NULL. -func (t *UnixMilli) Scan(src interface{}) error { - if src == nil { - return nil - } - - switch v := src.(type) { - case []byte: - return t.fromByteString(v) - // https://github.com/go-sql-driver/mysql/pull/1452 - case uint64: - if v > math.MaxInt64 { - return errors.Errorf("value %v out of range for int64", v) - } - - *t = UnixMilli(time.UnixMilli(int64(v))) - case int64: - *t = UnixMilli(time.UnixMilli(v)) - default: - return errors.Errorf("bad (u)int64/[]byte type assertion from %[1]v (%[1]T)", src) - } - - return nil -} - -// Value implements the driver.Valuer interface. -// Returns milliseconds. Supports SQL NULL. -func (t UnixMilli) Value() (driver.Value, error) { - if t.Time().IsZero() { - return nil, nil - } - - return t.Time().UnixMilli(), nil -} - -func (t *UnixMilli) fromByteString(data []byte) error { - i, err := strconv.ParseInt(string(data), 10, 64) - if err != nil { - return CantParseInt64(err, string(data)) - } - - *t = UnixMilli(time.UnixMilli(i)) - - return nil -} - -// Assert interface compliance. -var ( - _ encoding.TextMarshaler = (*UnixMilli)(nil) - _ encoding.TextUnmarshaler = (*UnixMilli)(nil) - _ json.Marshaler = (*UnixMilli)(nil) - _ json.Unmarshaler = (*UnixMilli)(nil) - _ driver.Valuer = (*UnixMilli)(nil) - _ sql.Scanner = (*UnixMilli)(nil) -) diff --git a/pkg/types/unix_milli_test.go b/pkg/types/unix_milli_test.go deleted file mode 100644 index 0cdbc7cd2..000000000 --- a/pkg/types/unix_milli_test.go +++ /dev/null @@ -1,149 +0,0 @@ -package types - -import ( - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "math" - "testing" - "time" - "unicode/utf8" -) - -func TestUnixMilli(t *testing.T) { - type testCase struct { - v UnixMilli - json string - text string - } - - tests := map[string]testCase{ - "Zero": {UnixMilli{}, "null", ""}, - "Non-zero": {UnixMilli(time.Unix(1234567890, 0)), "1234567890000", "1234567890000"}, - "Epoch": {UnixMilli(time.Unix(0, 0)), "0", "0"}, - "With milliseconds": {UnixMilli(time.Unix(1234567890, 62000000)), "1234567890062", "1234567890062"}, - } - - var runTests = func(t *testing.T, f func(*testing.T, testCase)) { - for name, test := range tests { - t.Run(name, func(t *testing.T) { - f(t, test) - }) - } - } - - t.Run("MarshalJSON", func(t *testing.T) { - runTests(t, func(t *testing.T, test testCase) { - actual, err := test.v.MarshalJSON() - require.NoError(t, err) - require.True(t, utf8.Valid(actual)) - require.Equal(t, test.json, string(actual)) - }) - }) - - t.Run("UnmarshalJSON", func(t *testing.T) { - runTests(t, func(t *testing.T, test testCase) { - var actual UnixMilli - err := actual.UnmarshalJSON([]byte(test.json)) - require.NoError(t, err) - require.Equal(t, test.v, actual) - }) - }) - - t.Run("MarshalText", func(t *testing.T) { - runTests(t, func(t *testing.T, test testCase) { - actual, err := test.v.MarshalText() - require.NoError(t, err) - require.True(t, utf8.Valid(actual)) - require.Equal(t, test.text, string(actual)) - }) - }) - - t.Run("UnmarshalText", func(t *testing.T) { - runTests(t, func(t *testing.T, test testCase) { - var actual UnixMilli - err := actual.UnmarshalText([]byte(test.text)) - require.NoError(t, err) - require.Equal(t, test.v, actual) - }) - }) -} - -func TestUnixMilli_Scan(t *testing.T) { - tests := []struct { - name string - v any - expected UnixMilli - expectErr bool - }{ - { - name: "Nil", - v: nil, - expected: UnixMilli{}, - }, - { - name: "Epoch", - v: int64(0), - expected: UnixMilli(time.Unix(0, 0)), - }, - { - name: "bytes", - v: []byte("1234567890062"), - expected: UnixMilli(time.Unix(1234567890, 62000000)), - }, - { - name: "Invalid bytes", - v: []byte("invalid"), - expectErr: true, - }, - { - name: "int64", - v: int64(1234567890062), - expected: UnixMilli(time.Unix(1234567890, 62000000)), - }, - { - name: "uint64", - v: uint64(1234567890062), - expected: UnixMilli(time.Unix(1234567890, 62000000)), - }, - { - name: "uint64 out of range for int64", - v: uint64(math.MaxInt64) + 1, - expectErr: true, - }, - { - name: "Invalid type", - v: "invalid", - expectErr: true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - var actual UnixMilli - err := actual.Scan(test.v) - if test.expectErr { - assert.Error(t, err) - } else { - assert.NoError(t, err) - assert.Equal(t, test.expected, actual) - } - }) - } -} - -func TestUnixMilli_Value(t *testing.T) { - t.Run("Zero", func(t *testing.T) { - var zero UnixMilli - actual, err := zero.Value() - require.NoError(t, err) - require.Nil(t, actual) - }) - - t.Run("Non-zero", func(t *testing.T) { - withMilliseconds := time.Unix(1234567890, 62000000) - expected := withMilliseconds.UnixMilli() - actual, err := UnixMilli(withMilliseconds).Value() - assert.NoError(t, err) - assert.Equal(t, expected, actual) - }) -} diff --git a/pkg/types/utils.go b/pkg/types/utils.go deleted file mode 100644 index 35c44aa42..000000000 --- a/pkg/types/utils.go +++ /dev/null @@ -1,52 +0,0 @@ -package types - -import ( - "encoding/json" - "fmt" - "github.com/pkg/errors" - "strings" -) - -// Name returns the declared name of type t. -func Name(t any) string { - s := strings.TrimLeft(fmt.Sprintf("%T", t), "*") - - return s[strings.LastIndex(s, ".")+1:] -} - -// CantDecodeHex wraps the given error with the given string that cannot be hex-decoded. -func CantDecodeHex(err error, s string) error { - return errors.Wrapf(err, "can't decode hex %q", s) -} - -// CantParseFloat64 wraps the given error with the specified string that cannot be parsed into float64. -func CantParseFloat64(err error, s string) error { - return errors.Wrapf(err, "can't parse %q into float64", s) -} - -// CantParseInt64 wraps the given error with the specified string that cannot be parsed into int64. -func CantParseInt64(err error, s string) error { - return errors.Wrapf(err, "can't parse %q into int64", s) -} - -// CantParseUint64 wraps the given error with the specified string that cannot be parsed into uint64. -func CantParseUint64(err error, s string) error { - return errors.Wrapf(err, "can't parse %q into uint64", s) -} - -// CantUnmarshalYAML wraps the given error with the designated value, which cannot be unmarshalled into. -func CantUnmarshalYAML(err error, v interface{}) error { - return errors.Wrapf(err, "can't unmarshal YAML into %T", v) -} - -// MarshalJSON calls json.Marshal and wraps any resulting errors. -func MarshalJSON(v interface{}) ([]byte, error) { - b, err := json.Marshal(v) - - return b, errors.Wrapf(err, "can't marshal JSON from %T", v) -} - -// UnmarshalJSON calls json.Unmarshal and wraps any resulting errors. -func UnmarshalJSON(data []byte, v interface{}) error { - return errors.Wrapf(json.Unmarshal(data, v), "can't unmarshal JSON into %T", v) -} diff --git a/pkg/types/uuid.go b/pkg/types/uuid.go deleted file mode 100644 index 02acbcdb1..000000000 --- a/pkg/types/uuid.go +++ /dev/null @@ -1,24 +0,0 @@ -package types - -import ( - "database/sql/driver" - "encoding" - "github.com/google/uuid" -) - -// UUID is like uuid.UUID, but marshals itself binarily (not like xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx) in SQL context. -type UUID struct { - uuid.UUID -} - -// Value implements driver.Valuer. -func (uuid UUID) Value() (driver.Value, error) { - return uuid.UUID[:], nil -} - -// Assert interface compliance. -var ( - _ encoding.TextUnmarshaler = (*UUID)(nil) - _ driver.Valuer = UUID{} - _ driver.Valuer = (*UUID)(nil) -) diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go deleted file mode 100644 index e7b7787a0..000000000 --- a/pkg/utils/utils.go +++ /dev/null @@ -1,167 +0,0 @@ -package utils - -import ( - "context" - "crypto/sha1" - "fmt" - "github.com/go-sql-driver/mysql" - "github.com/lib/pq" - "github.com/pkg/errors" - "golang.org/x/exp/utf8string" - "net" - "os" - "path/filepath" - "strings" - "time" -) - -// Timed calls the given callback with the time that has elapsed since the start. -// -// Timed should be installed by defer: -// -// func TimedExample(logger *zap.SugaredLogger) { -// defer utils.Timed(time.Now(), func(elapsed time.Duration) { -// logger.Debugf("Executed job in %s", elapsed) -// }) -// job() -// } -func Timed(start time.Time, callback func(elapsed time.Duration)) { - callback(time.Since(start)) -} - -// BatchSliceOfStrings groups the given keys into chunks of size count and streams them into a returned channel. -func BatchSliceOfStrings(ctx context.Context, keys []string, count int) <-chan []string { - batches := make(chan []string) - - go func() { - defer close(batches) - - for i := 0; i < len(keys); i += count { - end := i + count - if end > len(keys) { - end = len(keys) - } - - select { - case batches <- keys[i:end]: - case <-ctx.Done(): - return - } - } - }() - - return batches -} - -// IsContextCanceled returns whether the given error is context.Canceled. -func IsContextCanceled(err error) bool { - return errors.Is(err, context.Canceled) -} - -// Checksum returns the SHA-1 checksum of the data. -func Checksum(data interface{}) []byte { - var chksm [sha1.Size]byte - - switch data := data.(type) { - case string: - chksm = sha1.Sum([]byte(data)) - case []byte: - chksm = sha1.Sum(data) - default: - panic(fmt.Sprintf("Unable to create checksum for type %T", data)) - } - - return chksm[:] -} - -// IsDeadlock returns whether the given error signals serialization failure. -func IsDeadlock(err error) bool { - var e *mysql.MySQLError - if errors.As(err, &e) { - switch e.Number { - case 1205, 1213: - return true - default: - return false - } - } - - var pe *pq.Error - if errors.As(err, &pe) { - switch pe.Code { - case "40001", "40P01": - return true - } - } - - return false -} - -var ellipsis = utf8string.NewString("...") - -// Ellipsize shortens s to <=limit runes and indicates shortening by "...". -func Ellipsize(s string, limit int) string { - utf8 := utf8string.NewString(s) - switch { - case utf8.RuneCount() <= limit: - return s - case utf8.RuneCount() <= ellipsis.RuneCount(): - return ellipsis.String() - default: - return utf8.Slice(0, limit-ellipsis.RuneCount()) + ellipsis.String() - } -} - -// AppName returns the name of the executable that started this program (process). -func AppName() string { - exe, err := os.Executable() - if err != nil { - exe = os.Args[0] - } - - return filepath.Base(exe) -} - -// MaxInt returns the larger of the given integers. -func MaxInt(x, y int) int { - if x > y { - return x - } - - return y -} - -// IsUnixAddr indicates whether the given host string represents a Unix socket address. -// -// A host string that begins with a forward slash ('/') is considered Unix socket address. -func IsUnixAddr(host string) bool { - return strings.HasPrefix(host, "/") -} - -// JoinHostPort is like its equivalent in net., but handles UNIX sockets as well. -func JoinHostPort(host string, port int) string { - if IsUnixAddr(host) { - return host - } - - return net.JoinHostPort(host, fmt.Sprint(port)) -} - -// ChanFromSlice takes a slice of values and returns a channel from which these values can be received. -// This channel is closed after the last value was sent. -func ChanFromSlice[T any](values []T) <-chan T { - ch := make(chan T, len(values)) - for _, value := range values { - ch <- value - } - - close(ch) - - return ch -} - -// PrintErrorThenExit prints the given error to [os.Stderr] and exits with the specified error code. -func PrintErrorThenExit(err error, exitCode int) { - fmt.Fprintln(os.Stderr, err) - os.Exit(exitCode) -} diff --git a/pkg/utils/utils_test.go b/pkg/utils/utils_test.go deleted file mode 100644 index b0ea54b8f..000000000 --- a/pkg/utils/utils_test.go +++ /dev/null @@ -1,54 +0,0 @@ -package utils - -import ( - "github.com/stretchr/testify/require" - "testing" -) - -func TestChanFromSlice(t *testing.T) { - t.Run("Nil", func(t *testing.T) { - ch := ChanFromSlice[int](nil) - require.NotNil(t, ch) - requireClosedEmpty(t, ch) - }) - - t.Run("Empty", func(t *testing.T) { - ch := ChanFromSlice([]int{}) - require.NotNil(t, ch) - requireClosedEmpty(t, ch) - }) - - t.Run("NonEmpty", func(t *testing.T) { - ch := ChanFromSlice([]int{42, 23, 1337}) - require.NotNil(t, ch) - requireReceive(t, ch, 42) - requireReceive(t, ch, 23) - requireReceive(t, ch, 1337) - requireClosedEmpty(t, ch) - }) -} - -// requireReceive is a helper function to check if a value can immediately be received from a channel. -func requireReceive(t *testing.T, ch <-chan int, expected int) { - t.Helper() - - select { - case v, ok := <-ch: - require.True(t, ok, "receiving should return a value") - require.Equal(t, expected, v) - default: - require.Fail(t, "receiving should not block") - } -} - -// requireReceive is a helper function to check if the channel is closed and empty. -func requireClosedEmpty(t *testing.T, ch <-chan int) { - t.Helper() - - select { - case _, ok := <-ch: - require.False(t, ok, "receiving from channel should not return anything") - default: - require.Fail(t, "receiving should not block") - } -} diff --git a/pkg/version/version.go b/pkg/version/version.go deleted file mode 100644 index 250318c7e..000000000 --- a/pkg/version/version.go +++ /dev/null @@ -1,180 +0,0 @@ -package version - -import ( - "bufio" - "errors" - "fmt" - "os" - "runtime" - "runtime/debug" - "strconv" - "strings" -) - -type VersionInfo struct { - Version string - Commit string -} - -// Version determines version and commit information based on multiple data sources: -// - Version information dynamically added by `git archive` in the remaining to parameters. -// - A hardcoded version number passed as first parameter. -// - Commit information added to the binary by `go build`. -// -// It's supposed to be called like this in combination with setting the `export-subst` attribute for the corresponding -// file in .gitattributes: -// -// var Version = version.Version("1.0.0-rc2", "$Format:%(describe)$", "$Format:%H$") -// -// When exported using `git archive`, the placeholders are replaced in the file and this version information is -// preferred. Otherwise the hardcoded version is used and augmented with commit information from the build metadata. -func Version(version, gitDescribe, gitHash string) *VersionInfo { - const hashLen = 7 // Same truncation length for the commit hash as used by git describe. - - if !strings.HasPrefix(gitDescribe, "$") && !strings.HasPrefix(gitHash, "$") { - if strings.HasPrefix(gitDescribe, "%") { - // Only Git 2.32+ supports %(describe), older versions don't expand it but keep it as-is. - // Fall back to the hardcoded version augmented with the commit hash. - gitDescribe = version - - if len(gitHash) >= hashLen { - gitDescribe += "-g" + gitHash[:hashLen] - } - } - - return &VersionInfo{ - Version: gitDescribe, - Commit: gitHash, - } - } else { - commit := "" - - if info, ok := debug.ReadBuildInfo(); ok { - modified := false - - for _, setting := range info.Settings { - switch setting.Key { - case "vcs.revision": - commit = setting.Value - case "vcs.modified": - modified, _ = strconv.ParseBool(setting.Value) - } - } - - if len(commit) >= hashLen { - version += "-g" + commit[:hashLen] - - if modified { - version += "-dirty" - commit += " (modified)" - } - } - } - - return &VersionInfo{ - Version: version, - Commit: commit, - } - } -} - -// Print writes verbose version output to stdout. -func (v *VersionInfo) Print() { - fmt.Println("Icinga DB version:", v.Version) - fmt.Println() - - fmt.Println("Build information:") - fmt.Printf(" Go version: %s (%s, %s)\n", runtime.Version(), runtime.GOOS, runtime.GOARCH) - if v.Commit != "" { - fmt.Println(" Git commit:", v.Commit) - } - - if r, err := readOsRelease(); err == nil { - fmt.Println() - fmt.Println("System information:") - fmt.Println(" Platform:", r.Name) - fmt.Println(" Platform version:", r.DisplayVersion()) - } -} - -// osRelease contains the information obtained from the os-release file. -type osRelease struct { - Name string - Version string - VersionId string - BuildId string -} - -// DisplayVersion returns the most suitable version information for display purposes. -func (o *osRelease) DisplayVersion() string { - if o.Version != "" { - // Most distributions set VERSION - return o.Version - } else if o.VersionId != "" { - // Some only set VERSION_ID (Alpine Linux for example) - return o.VersionId - } else if o.BuildId != "" { - // Others only set BUILD_ID (Arch Linux for example) - return o.BuildId - } else { - return "(unknown)" - } -} - -// readOsRelease reads and parses the os-release file. -func readOsRelease() (*osRelease, error) { - for _, path := range []string{"/etc/os-release", "/usr/lib/os-release"} { - f, err := os.Open(path) - if err != nil { - if os.IsNotExist(err) { - continue // Try next path. - } else { - return nil, err - } - } - - o := &osRelease{ - Name: "Linux", // Suggested default as per os-release(5) man page. - } - - scanner := bufio.NewScanner(f) - for scanner.Scan() { - line := scanner.Text() - if strings.HasPrefix(line, "#") { - continue // Ignore comment. - } - - parts := strings.SplitN(line, "=", 2) - if len(parts) != 2 { - continue // Ignore empty or possibly malformed line. - } - - key := parts[0] - val := parts[1] - - // Unquote strings. This isn't fully compliant with the specification which allows using some shell escape - // sequences. However, typically quotes are only used to allow whitespace within the value. - if len(val) >= 2 && (val[0] == '"' || val[0] == '\'') && val[0] == val[len(val)-1] { - val = val[1 : len(val)-1] - } - - switch key { - case "NAME": - o.Name = val - case "VERSION": - o.Version = val - case "VERSION_ID": - o.VersionId = val - case "BUILD_ID": - o.BuildId = val - } - } - if err := scanner.Err(); err != nil { - return nil, err - } - - return o, nil - } - - return nil, errors.New("os-release file not found") -}