diff --git a/filter.go b/filter.go index 7ad646f..6e995b6 100644 --- a/filter.go +++ b/filter.go @@ -246,7 +246,7 @@ func createFilter(ctx context.Context, logger *zap.Logger, opts *FilterOptions, f.additionalHostnames = timedcache.New[string](filterLogger, false) nf4, nf6, err := openNfQueues(ctx, filterLogger, opts.TrafficQueue, newEnforcer, func(ipv6 bool) nfqueue.HookFunc { - return newGenericCallback(&f, ipv6) + return newGenericCallback(ctx, &f, ipv6) }) if err != nil { return nil, fmt.Errorf("error starting traffic nfqueues: %w", err) @@ -823,7 +823,7 @@ func newDNSResponseCallback(f *FilterManager, ipv6 bool) nfqueue.HookFunc { } } -func newGenericCallback(f *filter, ipv6 bool) nfqueue.HookFunc { +func newGenericCallback(ctx context.Context, f *filter, ipv6 bool) nfqueue.HookFunc { var queueNum uint16 if !ipv6 { queueNum = f.opts.TrafficQueue.IPv4 @@ -909,7 +909,7 @@ func newGenericCallback(f *filter, ipv6 bool) nfqueue.HookFunc { // validate that either the source or destination IP is allowed var verdict int - allowed, err := f.validateIPs(logger, src, dst) + allowed, err := f.validateIPs(ctx, logger, src, dst) if err != nil { logger.Error("error validating IPs", zap.Stringer("conn.src", src), zap.Stringer("conn.dst", dst), zap.NamedError("error", err)) verdict = nfqueue.NfDrop @@ -931,7 +931,7 @@ func newGenericCallback(f *filter, ipv6 bool) nfqueue.HookFunc { } } -func (f *filter) validateIPs(logger *zap.Logger, src, dst netip.Addr) (bool, error) { +func (f *filter) validateIPs(ctx context.Context, logger *zap.Logger, src, dst netip.Addr) (bool, error) { // check if the destination IP is allowed first, as most likely // we are validating an outbound connection if f.allowedIPs.EntryExists(dst) { @@ -948,7 +948,7 @@ func (f *filter) validateIPs(logger *zap.Logger, src, dst netip.Addr) (bool, err // preform reverse IP lookups on the destination and then source // IPs only if the IPs are not private if !dst.IsPrivate() { - allowed, err := f.lookupAndValidateIP(logger, dst) + allowed, err := f.lookupAndValidateIP(ctx, logger, dst) if err != nil { return false, err } @@ -958,15 +958,14 @@ func (f *filter) validateIPs(logger *zap.Logger, src, dst netip.Addr) (bool, err } if !src.IsPrivate() { - return f.lookupAndValidateIP(logger, src) + return f.lookupAndValidateIP(ctx, logger, src) } return false, nil } -func (f *filter) lookupAndValidateIP(logger *zap.Logger, ip netip.Addr) (bool, error) { - // TODO: build from top-level context - ctx, cancel := context.WithTimeout(context.Background(), dnsQueryTimeout) +func (f *filter) lookupAndValidateIP(ctx context.Context, logger *zap.Logger, ip netip.Addr) (bool, error) { + ctx, cancel := context.WithTimeout(ctx, dnsQueryTimeout) defer cancel() logger.Info("preforming reverse IP lookup", zap.Stringer("ip", ip)) diff --git a/timedcache/timed_cache.go b/timedcache/timed_cache.go index 770aa08..2f559f8 100644 --- a/timedcache/timed_cache.go +++ b/timedcache/timed_cache.go @@ -92,10 +92,13 @@ func (t *TimedCache[T]) AddEntry(entry T, ttl time.Duration) { case <-timer.C: running = false case s := <-status: - if s == reset { + switch s { + case reset: // wait until timer is finished resetting <-status - } else if s == stop { + case start: + // the timer has started, wait for another status + case stop: return } }