Skip to content

Commit

Permalink
Check banned IPs via channel
Browse files Browse the repository at this point in the history
Use the more idiomatic way of communicating rather than sharing memory:
The goroutine handles answering queries rather than using a lock to
share memory about the banlist.
  • Loading branch information
Javex committed May 25, 2024
1 parent 28d7ff2 commit ca78ae9
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 20 deletions.
63 changes: 43 additions & 20 deletions banlist.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,22 @@ import (
"fmt"
"os"
"path/filepath"
"sync"

"github.com/caddyserver/caddy/v2"
"github.com/fsnotify/fsnotify"
"go.uber.org/zap"
)

type banQuery struct {
response chan bool
ip string
}

type Banlist struct {
ctx caddy.Context
bannedIps []string
shutdown chan bool
lock *sync.RWMutex
queries chan banQuery
logger *zap.Logger
banfile *string
reload chan chan bool
Expand All @@ -25,13 +29,12 @@ type Banlist struct {

func NewBanlist(ctx caddy.Context, logger *zap.Logger, banfile *string) Banlist {
banlist := Banlist{
ctx: ctx,
bannedIps: make([]string, 0),
lock: new(sync.RWMutex),
logger: logger,
banfile: banfile,
reload: make(chan chan bool),
reloadSubs: make([]chan bool, 0),
ctx: ctx,
bannedIps: make([]string, 0),
queries: make(chan banQuery),
logger: logger,
banfile: banfile,
reload: make(chan chan bool),
}
return banlist
}
Expand All @@ -41,16 +44,15 @@ func (b *Banlist) Start() {
}

func (b *Banlist) IsBanned(remote_ip string) bool {
b.lock.RLock()
defer b.lock.RUnlock()

for _, ip := range b.bannedIps {
b.logger.Debug("Checking IP", zap.String("ip", ip), zap.String("remote_ip", remote_ip))
if ip == remote_ip {
return true
}
response := make(chan bool)
query := banQuery{
response,
remote_ip,
}
return false
b.queries <- query
isBanned := <-response
close(response)
return isBanned
}

func (b *Banlist) Reload() {
Expand Down Expand Up @@ -98,13 +100,19 @@ func (b *Banlist) monitorBannedIps() {
}
b.logger.Debug("Banlist reloaded")
resp <- true
case query := <-b.queries:
// Respond to query whether an IP has been banned
b.logger.Debug("Handling ban query", zap.String("remote_ip", query.ip))
b.handleQuery(query)
case err, ok := <-watcher.Errors:
// Handle errors from fsnotify
if !ok {
b.logger.Error("Error channel closed unexpectedly, stopping monitor")
return
}
b.logger.Error("Error from fsnotify", zap.Error(err))
case event, ok := <-watcher.Events:
// Respond to changed file events from fsnotify
if !ok {
b.logger.Error("Watcher closed unexpectedly, stopping monitor")
return
Expand All @@ -120,12 +128,25 @@ func (b *Banlist) monitorBannedIps() {
}
}
case <-b.ctx.Done():
// Caddy will close the context when it's time to shut down
b.logger.Debug("Context finished, shutting down")
return
}
}
}

func (b *Banlist) handleQuery(query banQuery) {
remote_ip := query.ip
for _, ip := range b.bannedIps {
b.logger.Debug("Checking IP", zap.String("ip", ip), zap.String("remote_ip", remote_ip))
if ip == remote_ip {
query.response <- true
return
}
}
query.response <- false
}

// Provide a channel that will receive a boolean true value whenever the list
// of banned IPs has been reloaded. Mostly useful for tests so they can wait
// for the inotify event rather than sleep
Expand All @@ -141,12 +162,14 @@ func (b *Banlist) loadBannedIps() error {
b.logger.Error("Error getting list of banned IPs")
return err
} else {
b.lock.Lock()
b.bannedIps = bannedIps
b.lock.Unlock()
for _, n := range b.reloadSubs {
n <- true
}
// only respond once then clear subs, otherwise further attempts might
// block as the receiver only reads one event rather than constantly
// draining it.
b.reloadSubs = nil
return nil
}
}
Expand Down
5 changes: 5 additions & 0 deletions fail2ban_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@ func TestHeaderBan(t *testing.T) {
if got, exp := m.Match(req), true; got != exp {
t.Errorf("unexpected match. got: %t, exp: %t", got, exp)
}

// Trigger explicit reload just to give the goroutine enough time to spin up,
// otherwise the defer above will delete the temporary directory before it
// had time to initialise
m.banlist.Reload()
}

func TestBanIp(t *testing.T) {
Expand Down

0 comments on commit ca78ae9

Please sign in to comment.