Skip to content

Commit

Permalink
Improve performance by caching IPs in memory
Browse files Browse the repository at this point in the history
Instead of reading the file on every request (very slow), the list of
IPs is loaded through various triggers and kept in memory. On each
request, it is now a much faster lookup in memory. To ensure the list
remains up to date, fsnotify is used to watch the file for changes and
reload it.
  • Loading branch information
Javex committed May 25, 2024
1 parent 72787a8 commit dbfadff
Show file tree
Hide file tree
Showing 7 changed files with 240 additions and 57 deletions.
197 changes: 197 additions & 0 deletions banlist.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
package caddy_fail2ban

import (
"bufio"
"errors"
"fmt"
"os"
"path/filepath"
"sync"

"github.com/fsnotify/fsnotify"
"go.uber.org/zap"
)

type Banlist struct {
bannedIps []string
shutdown chan bool
lock *sync.RWMutex
logger *zap.Logger
banfile *string
reload chan chan bool
reloadSubs []chan bool
}

func NewBanlist(logger *zap.Logger, banfile *string) Banlist {
banlist := Banlist{
bannedIps: make([]string, 0),
shutdown: make(chan bool),
lock: new(sync.RWMutex),
logger: logger,
banfile: banfile,
reload: make(chan chan bool),
reloadSubs: make([]chan bool, 0),
}
return banlist
}

func (b *Banlist) Start() {
go b.monitorBannedIps()
}

func (b *Banlist) IsBanned(remote_ip string) bool {
b.lock.RLock()
defer b.lock.RUnlock()

for _, ip := range b.bannedIps {
b.logger.Debug("Checking IP", zap.String("ip", ip), zap.String("remote_ip", remote_ip))
if ip == remote_ip {
return true
}
}
return false
}

func (b *Banlist) Stop() error {
if b.shutdown != nil {
b.shutdown <- true
_, ok := <-b.shutdown
if ok {
b.logger.Error("Failed to shutdown monitor goroutine")
return errors.New("shutdown of monitor failed")
}
}
return nil
}

func (b *Banlist) Reload() {
resp := make(chan bool)

b.reload <- resp
<-resp
}

func (b *Banlist) monitorBannedIps() {
b.logger.Info("Starting monitor for banned IPs")
defer func() {
b.logger.Info("Shutting down monitor for banned IPs")
close(b.shutdown)
}()

// Load initial list
err := b.loadBannedIps()
if err != nil {
b.logger.Error("Error loading initial list of banned IPs", zap.Error(err))
return
}

watcher, err := fsnotify.NewWatcher()
if err != nil {
b.logger.Error("Error creating monitor", zap.Error(err))
return
}
defer watcher.Close()

// Watch the directory that the banfile is in as sometimes files can be
// written to by replacement (see https://pkg.go.dev/github.com/fsnotify/fsnotify#readme-watching-a-file-doesn-t-work-well)
err = watcher.Add(filepath.Dir(*b.banfile))
if err != nil {
b.logger.Error("Error monitoring banfile", zap.Error(err), zap.String("banfile", *b.banfile))
}

for {
select {
case resp := <-b.reload:
// Trigger reload of banned IPs
err = b.loadBannedIps()
if err != nil {
b.logger.Error("Error when trying to explicitly reloading list of banned IPs", zap.Error(err))
return
}
b.logger.Debug("Banlist reloaded")
resp <- true
case err, ok := <-watcher.Errors:
if !ok {
b.logger.Error("Error channel closed unexpectedly, stopping monitor")
return
}
b.logger.Error("Error from fsnotify", zap.Error(err))
case event, ok := <-watcher.Events:
if !ok {
b.logger.Error("Watcher closed unexpectedly, stopping monitor")
return
}
// We get events for the whole directory but only want to do work if the
// changed file is our banfile
if event.Has(fsnotify.Write) && event.Name == *b.banfile {
b.logger.Debug("File has changed, reloading banned IPs")
err = b.loadBannedIps()
if err != nil {
b.logger.Error("Error when trying to reload banned IPs because of inotify event", zap.Error(err))
return
}
}
case <-b.shutdown:
// Receive signal to finish
b.logger.Debug("Received shutdown signal")
return
}
}
}

// Provide a channel that will receive a boolean true value whenever the list
// of banned IPs has been reloaded. Mostly useful for tests so they can wait
// for the inotify event rather than sleep
func (b *Banlist) subscribeToReload(notify chan bool) {
b.reloadSubs = append(b.reloadSubs, notify)
}

// loadBannedIps loads list of banned IPs from file on disk and notifies
// subscribers in case it was successful
func (b *Banlist) loadBannedIps() error {
bannedIps, err := b.getBannedIps()
if err != nil {
b.logger.Error("Error getting list of banned IPs")
return err
} else {
b.lock.Lock()
b.bannedIps = bannedIps
b.lock.Unlock()
for _, n := range b.reloadSubs {
n <- true
}
return nil
}
}

func (b *Banlist) getBannedIps() ([]string, error) {

// Open banfile
// Try to open file
banfileHandle, err := os.Open(*b.banfile)
if err != nil {
b.logger.Info("Creating new file since Open failed", zap.String("banfile", *b.banfile), zap.Error(err))
// Try to create new file, maybe the file didn't exist yet
banfileHandle, err = os.Create(*b.banfile)
if err != nil {
b.logger.Error("Error creating banfile", zap.String("banfile", *b.banfile), zap.Error(err))
return nil, fmt.Errorf("cannot open or create banfile: %v", err)
}
}
defer banfileHandle.Close()

// read banned IPs
bannedIps := make([]string, 0)
scanner := bufio.NewScanner(banfileHandle)
for scanner.Scan() {
line := scanner.Text()
b.logger.Debug("Adding banned IP to list", zap.String("banned_addr", line))
bannedIps = append(bannedIps, line)
}

if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("error parsing banfile: %v", err)
}

return bannedIps, nil
}
67 changes: 12 additions & 55 deletions fail2ban.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
package caddy_fail2ban

import (
"bufio"
"fmt"
"net"
"net/http"
"os"

"github.com/caddyserver/caddy/v2"
"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
Expand All @@ -15,21 +13,15 @@ import (

func init() {
caddy.RegisterModule(Fail2Ban{})
// httpcaddyfile.RegisterHandlerDirective("visitor_ip", parseCaddyfile)
}

// Fail2Ban implements an HTTP handler that writes the
// visitor's IP address to a file or stream.
// Fail2Ban implements an HTTP handler that checks a specified file for banned
// IPs and matches if they are found
type Fail2Ban struct {
// The file or stream to write to. Can be "stdout"
// or "stderr".
Output string `json:"output,omitempty"`

// w io.Writer

Banfile string `json:"banfile"`

logger *zap.Logger
logger *zap.Logger
banlist Banlist
}

// CaddyModule returns the Caddy module information.
Expand All @@ -43,39 +35,13 @@ func (Fail2Ban) CaddyModule() caddy.ModuleInfo {
// Provision implements caddy.Provisioner.
func (m *Fail2Ban) Provision(ctx caddy.Context) error {
m.logger = ctx.Logger()
m.banlist = NewBanlist(m.logger, &m.Banfile)
m.banlist.Start()
return nil
}

func (m *Fail2Ban) getBannedIps() ([]string, error) {

// Open banfile
// Try to open file
banfileHandle, err := os.Open(m.Banfile)
if err != nil {
m.logger.Info("Creating new file at since Open failed", zap.String("banfile", m.Banfile), zap.Error(err))
// Try to create new file, maybe the file didn't exist yet
banfileHandle, err = os.Create(m.Banfile)
if err != nil {
m.logger.Error("Error creating banfile", zap.String("banfile", m.Banfile), zap.Error(err))
return nil, fmt.Errorf("cannot open or create banfile: %v", err)
}
}
defer banfileHandle.Close()

// read banned IPs
bannedIps := make([]string, 0)
scanner := bufio.NewScanner(banfileHandle)
for scanner.Scan() {
line := scanner.Text()
m.logger.Debug("Adding banned IP to list", zap.String("banned_addr", line))
bannedIps = append(bannedIps, line)
}

if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("error parsing banfile: %v", err)
}

return bannedIps, nil
func (m *Fail2Ban) Cleanup() error {
return m.banlist.Stop()
}

// Validate implements caddy.Validator.
Expand All @@ -101,21 +67,11 @@ func (m *Fail2Ban) Match(req *http.Request) bool {
return true
}

// check IPs, too
bannedIps, err := m.getBannedIps()
if err != nil {
m.logger.Error("error getting banned IPs", zap.Error(err))
// Deny by default
if m.banlist.IsBanned(remote_ip) == true {
m.logger.Info("banned IP", zap.String("remote_addr", remote_ip))
return true
}

for _, ip := range bannedIps {
if ip == remote_ip {
m.logger.Debug("banned IP", zap.String("remote_addr", remote_ip))
return true
}
}

m.logger.Debug("received request", zap.String("remote_addr", remote_ip))
return false
}
Expand Down Expand Up @@ -146,7 +102,8 @@ func (m *Fail2Ban) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {

// Interface guards
var (
_ caddy.Provisioner = (*Fail2Ban)(nil)
_ caddy.Provisioner = (*Fail2Ban)(nil)
_ caddy.CleanerUpper = (*Fail2Ban)(nil)
// _ caddy.Validator = (*Fail2Ban)(nil)
_ caddyhttp.RequestMatcher = (*Fail2Ban)(nil)
_ caddyfile.Unmarshaler = (*Fail2Ban)(nil)
Expand Down
23 changes: 22 additions & 1 deletion fail2ban_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,20 @@ func TestModule(t *testing.T) {
if err != nil {
t.Errorf("error provisioning: %v", err)
}
defer func() {
err := m.Cleanup()
if err != nil {
t.Fatalf("unexpected error on cleanup: %v", err)
}
}()

req := httptest.NewRequest("GET", "https://127.0.0.1", strings.NewReader(""))

if got, exp := m.Match(req), false; got != exp {
t.Errorf("unexpected match. got: %t, exp: %t", got, exp)
}

bannedIps, err := m.getBannedIps()
bannedIps, err := m.banlist.getBannedIps()
if err != nil {
t.Errorf("error loading banned ips: %v", err)
}
Expand Down Expand Up @@ -87,6 +93,12 @@ func TestHeaderBan(t *testing.T) {
if err != nil {
t.Errorf("error provisioning: %v", err)
}
defer func() {
err := m.Cleanup()
if err != nil {
t.Fatalf("unexpected error on cleanup: %v", err)
}
}()

req := httptest.NewRequest("GET", "https://127.0.0.1", strings.NewReader(""))
req.Header.Add("X-Caddy-Ban", "1")
Expand Down Expand Up @@ -115,6 +127,12 @@ func TestBanIp(t *testing.T) {
if err != nil {
t.Errorf("error provisioning: %v", err)
}
defer func() {
err := m.Cleanup()
if err != nil {
t.Fatalf("unexpected error on cleanup: %v", err)
}
}()

req := httptest.NewRequest("GET", "https://127.0.0.1", strings.NewReader(""))
req.RemoteAddr = "127.0.0.1:1337"
Expand All @@ -124,7 +142,10 @@ func TestBanIp(t *testing.T) {
}

// ban IP
reloadEvent := make(chan bool)
m.banlist.subscribeToReload(reloadEvent)
os.WriteFile(fail2banFile, []byte("127.0.0.1"), 0644)
<-reloadEvent

req = httptest.NewRequest("GET", "https://127.0.0.1", strings.NewReader(""))
req.RemoteAddr = "127.0.0.1:1337"
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ go 1.21.0

require (
github.com/caddyserver/caddy/v2 v2.7.6
github.com/fsnotify/fsnotify v1.7.0
go.uber.org/zap v1.25.0
)

Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5Kwzbycv
github.com/franela/goblin v0.0.0-20200105215937-c9ffbefa60db/go.mod h1:7dvUGVsVBjqR7JHJk0brhHOZYGmfBYOrK0ZhYMEtBr4=
github.com/franela/goreq v0.0.0-20171204163338-bcd34c9993f8/go.mod h1:ZhphrRTfi2rbfLwlschooIH4+wKKDR4Pdxhh+TRoA20=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
github.com/go-kit/kit v0.4.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
Expand Down
Loading

0 comments on commit dbfadff

Please sign in to comment.