From 6e8615a7ebb6487a9ae734a07cf68cc1bd65c04c Mon Sep 17 00:00:00 2001 From: Andrew LeFevre Date: Sun, 18 Jun 2023 20:22:10 -0400 Subject: [PATCH] periodically fix missing or modified rules --- cmd/whalewall/main.go | 8 +++++++- create.go | 8 ++++++-- db.go | 2 +- manager.go | 12 ++++++++++-- sync.go | 35 +++++++++++++++++++++++++++++++++++ whalewall_test.go | 10 +++++----- 6 files changed, 64 insertions(+), 11 deletions(-) diff --git a/cmd/whalewall/main.go b/cmd/whalewall/main.go index 6e4bb15..e3a4c37 100644 --- a/cmd/whalewall/main.go +++ b/cmd/whalewall/main.go @@ -35,6 +35,7 @@ func mainRetCode() int { debugLogs := flag.Bool("debug", false, "enable debug logging") logPath := flag.String("l", "stdout", "path to log to") timeout := flag.Duration("t", 10*time.Second, "timeout for Docker API requests") + watchInterval := flag.Duration("i", time.Minute, "interval to check created container rules") displayVersion := flag.Bool("version", false, "print version and build information and exit") flag.Parse() @@ -49,6 +50,11 @@ func mainRetCode() int { return 0 } + if *watchInterval <= 0 { + log.Println("-i must be greater than 0") + return 1 + } + // build logger logCfg := zap.NewProductionConfig() logCfg.OutputPaths = []string{*logPath} @@ -92,7 +98,7 @@ func mainRetCode() int { ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) defer cancel() - r, err := whalewall.NewRuleManager(ctx, logger, sqliteFile, *timeout) + r, err := whalewall.NewRuleManager(ctx, logger, sqliteFile, *timeout, *watchInterval) if err != nil { logger.Error("error initializing", zap.Error(err)) } diff --git a/create.go b/create.go index 282027b..35ca503 100644 --- a/create.go +++ b/create.go @@ -76,7 +76,11 @@ func (r *RuleManager) createContainerRules(ctx context.Context, container types. contName := stripName(container.Name) logger := r.logger.With(zap.String("container.id", container.ID[:12]), zap.String("container.name", contName)) - logger.Info("creating rules", zap.Bool("container.is_new", isNew)) + if isNew { + logger.Info("creating rules", zap.Bool("container.is_new", isNew)) + } else { + logger.Debug("watching rules", zap.Bool("container.is_new", isNew)) + } // check that network settings are valid if container.NetworkSettings == nil { @@ -332,7 +336,7 @@ func (r *RuleManager) createContainerRules(ctx context.Context, container types. logger.Debug("adding to database") - if err := r.addContainer(ctx, tx, container.ID, contName, service, addrs, estContainers); err != nil { + if err := r.addContainerInfo(ctx, tx, container.ID, contName, service, addrs, estContainers); err != nil { return fmt.Errorf("error adding container information to database: %w", err) } diff --git a/db.go b/db.go index 1597693..5b81bfc 100644 --- a/db.go +++ b/db.go @@ -20,7 +20,7 @@ func (r *RuleManager) containerExists(ctx context.Context, db database.Querier, return exists == 1, nil } -func (r *RuleManager) addContainer(ctx context.Context, tx database.TX, id, name, service string, addrs map[string][]byte, estContainers map[string]struct{}) error { +func (r *RuleManager) addContainerInfo(ctx context.Context, tx database.TX, id, name, service string, addrs map[string][]byte, estContainers map[string]struct{}) error { for _, addr := range addrs { err := tx.AddContainerAddr(ctx, addr, id) if err != nil { diff --git a/manager.go b/manager.go index a8823ce..204208c 100644 --- a/manager.go +++ b/manager.go @@ -54,6 +54,8 @@ type RuleManager struct { containerTracker *container.Tracker + watchInterval time.Duration + createCh chan containerDetails deleteCh chan string @@ -70,7 +72,7 @@ type containerDetails struct { isNew bool } -func NewRuleManager(ctx context.Context, logger *zap.Logger, dbFile string, timeout time.Duration) (*RuleManager, error) { +func NewRuleManager(ctx context.Context, logger *zap.Logger, dbFile string, timeout, watchInterval time.Duration) (*RuleManager, error) { r := RuleManager{ stopping: make(chan struct{}), done: make(chan struct{}), @@ -89,6 +91,7 @@ func NewRuleManager(ctx context.Context, logger *zap.Logger, dbFile string, time return nftables.New() }, containerTracker: container.NewTracker(logger), + watchInterval: watchInterval, createCh: make(chan containerDetails), deleteCh: make(chan string), } @@ -126,7 +129,12 @@ func (r *RuleManager) Start(ctx context.Context) error { r.logger.Error("error syncing containers", zap.Error(err)) } - r.wg.Add(1) + r.wg.Add(2) + go func() { + defer r.wg.Done() + + r.watchContainers(ctx) + }() go func() { defer r.wg.Done() diff --git a/sync.go b/sync.go index 6876317..6133e2b 100644 --- a/sync.go +++ b/sync.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "slices" + "time" "github.com/docker/docker/api/types" "github.com/docker/docker/api/types/filters" @@ -74,6 +75,40 @@ func (r *RuleManager) syncContainers(ctx context.Context) error { return nil } +// watchContainers periodically checks that container rules created by +// whalewall haven't been deleted, modified or added to and fixes them +// if necessary. +func (r *RuleManager) watchContainers(ctx context.Context) { + ticker := time.NewTicker(r.watchInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + conts, err := r.db.GetContainers(ctx) + if err != nil { + r.logger.Error("error getting containers from database", zap.Error(err)) + continue + } + for _, c := range conts { + container, err := r.dockerCli.ContainerInspect(ctx, c.ID) + if err != nil { + r.logger.Error("error inspecting container", zap.String("container.id", c.ID[:12]), zap.Error(err)) + continue + } + r.createCh <- containerDetails{ + container: container, + isNew: false, + } + } + case <-ctx.Done(): + return + case <-r.stopping: + return + } + } +} + func whalewallEnabled(labels map[string]string) (bool, error) { e, ok := labels[enabledLabel] if !ok { diff --git a/whalewall_test.go b/whalewall_test.go index 7344873..cbab03f 100644 --- a/whalewall_test.go +++ b/whalewall_test.go @@ -270,7 +270,7 @@ func startFunc(t *testing.T, is *is.I, tempDir string) func() { logger.Info("starting whalewall") ctx, cancel := context.WithCancel(context.Background()) dbFile := filepath.Join(tempDir, "db.sqlite") - r, err := NewRuleManager(ctx, logger, dbFile, defaultTimeout) + r, err := NewRuleManager(ctx, logger, dbFile, defaultTimeout, time.Minute) is.NoErr(err) err = r.Start(ctx) is.NoErr(err) @@ -2335,7 +2335,7 @@ mapped_ports: is := is.New(t) dbFile := filepath.Join(t.TempDir(), "db.sqlite") - r, err := NewRuleManager(context.Background(), logger, dbFile, defaultTimeout) + r, err := NewRuleManager(context.Background(), logger, dbFile, defaultTimeout, defaultTimeout) is.NoErr(err) var dockerCli *mockDockerClient @@ -2559,7 +2559,7 @@ output: } dbFile := filepath.Join(t.TempDir(), "db.sqlite") - r, err := NewRuleManager(context.Background(), logger, dbFile, defaultTimeout) + r, err := NewRuleManager(context.Background(), logger, dbFile, defaultTimeout, time.Minute) is.NoErr(err) dockerCli := newMockDockerClient(nil) @@ -2701,7 +2701,7 @@ output: } dbFile := filepath.Join(t.TempDir(), "db.sqlite") - r, err := NewRuleManager(context.Background(), logger, dbFile, defaultTimeout) + r, err := NewRuleManager(context.Background(), logger, dbFile, defaultTimeout, time.Minute) is.NoErr(err) dockerCli := newMockDockerClient(nil) @@ -2838,7 +2838,7 @@ output: is.NoErr(err) dbFile := filepath.Join(t.TempDir(), "db.sqlite") - r, err := NewRuleManager(context.Background(), logger, dbFile, defaultTimeout) + r, err := NewRuleManager(context.Background(), logger, dbFile, defaultTimeout, time.Minute) is.NoErr(err) // configure database to pause before committing so we can cancel