Skip to content

Commit

Permalink
compute net mask on initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
joecorall committed Feb 28, 2025
1 parent 85b5d50 commit 7a58dd2
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 64 deletions.
79 changes: 57 additions & 22 deletions ci/test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"encoding/json"
"fmt"
"io"
"log"
"log/slog"
"math/rand"
"net"
Expand All @@ -18,8 +17,10 @@ import (
cp "github.com/libops/captcha-protect"
)

var rateLimit = 5
var exemptIps []*net.IPNet
var (
rateLimit = 5
exemptIps []*net.IPNet
)

const numIPs = 100
const parallelism = 10
Expand All @@ -34,11 +35,7 @@ func main() {
"fc00::/8",
}
for _, ip := range _ips {
parsedIp, err := cp.ParseCIDR(ip)
if err != nil {
slog.Error("error parsing cidr", "ip", ip, "err", err)
os.Exit(1)
}
parsedIp := parseCIDR(ip)
exemptIps = append(exemptIps, parsedIp)
}

Expand Down Expand Up @@ -86,10 +83,23 @@ func generateUniquePublicIPs(n int) []string {
ipSet := make(map[string]struct{})
var ips []string
config := cp.CreateConfig()
bc := &cp.CaptchaProtect{}
bc.SetExemptIps(exemptIps)
err := bc.SetIpv4Mask(16)
if err != nil {
slog.Error("unable to set ipv4 mask")
os.Exit(1)
}

err = bc.SetIpv6Mask(64)
if err != nil {
slog.Error("unable to set ipv6 mask")
os.Exit(1)
}

for len(ips) < n {
ip := randomPublicIP(config)
ip, ipRange := cp.ParseIp(ip, 16, 64)
ip, ipRange := bc.ParseIp(ip)
if _, exists := ipSet[ipRange]; !exists {
ipSet[ipRange] = struct{}{}
ips = append(ips, ip)
Expand Down Expand Up @@ -142,7 +152,9 @@ func runParallelChecks(ips []string, rateLimit int) {
fmt.Printf("Checking %s\n", ip)
output := httpRequest(ip)
if output != "" {
log.Fatalf("Unexpected output for %s: %s", ip, output)
slog.Error("Unexpected output", "ip", ip, "output", output)
os.Exit(1)

}
}(ip)
}
Expand All @@ -157,7 +169,8 @@ func ensureRedirect(ips []string) {
output := httpRequest(ip)

if output != expectedRedirectURL {
log.Fatalf("Unexpected output for %s: %s", ip, output)
slog.Error("Unexpected output", "ip", ip, "output", output)
os.Exit(1)
}

fmt.Printf("Got a redirect! %s\n", output)
Expand All @@ -177,12 +190,15 @@ func httpRequest(ip string) string {

req, err := http.NewRequest("GET", "http://localhost", nil)
if err != nil {
log.Fatalf("Failed to create request: %v", err)
slog.Error("Failed to create request", "err", err)
os.Exit(1)
}
req.Header.Set("X-Forwarded-For", ip)
resp, err := client.Do(req)
if err != nil {
log.Fatalf("Request failed: %v", err)
slog.Error("Request failed", "err", err)
os.Exit(1)

}
defer resp.Body.Close()

Expand All @@ -192,7 +208,9 @@ func httpRequest(ip string) string {
if err == http.ErrNoLocation {
return ""
}
log.Fatalf("Failed to get redirect URL: %v", err)
slog.Error("Failed to get redirect URL", "err", err)
os.Exit(1)

}

return strings.TrimSpace(location.String())
Expand All @@ -211,38 +229,55 @@ func runCommand(name string, args ...string) {
cmd.Env = append(cmd.Env, fmt.Sprintf("TRAEFIK_TAG=%s", tt))
}
if err := cmd.Run(); err != nil {
log.Fatalf("Command failed: %v", err)
slog.Error("Command failed", "err", err)
os.Exit(1)
}
}

func checkStateReload() {
resp, err := http.Get("http://localhost/captcha-protect/stats")
if err != nil {
log.Fatalf("Failed to make GET request: %v", err)
slog.Error("Failed to make GET request", "err", err)
os.Exit(1)
}
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
if err != nil {
log.Fatalf("Failed to read response body: %v", err)
slog.Error("Failed to read response body", "err", err)
os.Exit(1)

}
var jsonResponse map[string]interface{}
err = json.Unmarshal(body, &jsonResponse)
if err != nil {
log.Fatalf("Failed to unmarshal JSON: %v", err)
slog.Error("Failed to unmarshal JSON", "err", err)
os.Exit(1)

}
bots, exists := jsonResponse["bots"]
if !exists {
log.Fatalf("Key 'bots' not found in JSON response")
slog.Error("Key 'bots' not found in JSON response")
os.Exit(1)
}
botsMap, ok := bots.(map[string]interface{})
if !ok {
log.Fatalf("'bots' is not an array")
slog.Error("'bots' is not an array")
os.Exit(1)
}

if len(botsMap) != numIPs {
log.Fatalf("Expected %d bots, but got %d", numIPs, len(botsMap))
slog.Error("Unexpected number of bots", "expected", numIPs, "received", len(botsMap))
os.Exit(1)
}

log.Println("State reloaded successfully!")
slog.Info("State reloaded successfully!")
}

func parseCIDR(cidr string) *net.IPNet {
_, block, err := net.ParseCIDR(cidr)
if err != nil {
slog.Error("Failed to parse CIDR", "cidr", cidr, "err", err)
}
return block
}
59 changes: 41 additions & 18 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ type CaptchaProtect struct {
captchaConfig CaptchaConfig
exemptIps []*net.IPNet
tmpl *template.Template
ipv4Mask net.IPMask
ipv6Mask net.IPMask
}

type CaptchaConfig struct {
Expand Down Expand Up @@ -112,13 +114,6 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
}
logLevel.Set(level)

if config.IPv4SubnetMask < 8 || config.IPv4SubnetMask > 32 {
return nil, fmt.Errorf("invalid ipv4 mask: %d. Must be between 8 and 32", config.IPv4SubnetMask)
}
if config.IPv6SubnetMask < 8 || config.IPv6SubnetMask > 128 {
return nil, fmt.Errorf("invalid ipv6 mask: %d. Must be between 8 and 128", config.IPv6SubnetMask)
}

expiration := time.Duration(config.Window) * time.Second
log.Debug("Captcha config", "config", config)

Expand Down Expand Up @@ -172,6 +167,16 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
tmpl: tmpl,
}

err = bc.SetIpv4Mask(config.IPv4SubnetMask)
if err != nil {
return nil, err
}

err = bc.SetIpv6Mask(config.IPv6SubnetMask)
if err != nil {
return nil, err
}

// set the captcha config based on the provider
// thanks to https://github.com/maxlerebourg/crowdsec-bouncer-traefik-plugin/blob/4708d76854c7ae95fa7313c46fbe21959be2fff1/pkg/captcha/captcha.go#L39-L55
// for the struct/idea
Expand Down Expand Up @@ -444,36 +449,50 @@ func (bc *CaptchaProtect) getClientIP(req *http.Request) (string, string) {
ip = host
}

return ParseIp(ip, bc.config.IPv4SubnetMask, bc.config.IPv6SubnetMask)
return bc.ParseIp(ip)
}

func ParseIp(ip string, ipv4Mask, ipv6Mask int) (string, string) {
func (bc *CaptchaProtect) ParseIp(ip string) (string, string) {
parsedIP := net.ParseIP(ip)
if parsedIP == nil {
return ip, ip
}

// For IPv4 addresses
if parsedIP.To4() != nil {
mask := net.CIDRMask(ipv4Mask, 32)
subnet := parsedIP.Mask(mask)
subnet := parsedIP.Mask(bc.ipv4Mask)
return ip, subnet.String()
}

// For IPv6 addresses
if parsedIP.To16() != nil {
ipParts := strings.Split(ip, ":")
// Calculate the number of hextets required.
required := ipv6Mask / 16
if len(ipParts) >= required {
subnet := strings.Join(ipParts[:required], ":")
return ip, subnet
}
subnet := parsedIP.Mask(bc.ipv6Mask)
return ip, subnet.String()
}

log.Warn("Unknown ip version", "ip", ip)

return ip, ip
}

func (bc *CaptchaProtect) SetIpv4Mask(m int) error {
if m < 8 || m > 32 {
return fmt.Errorf("invalid ipv4 mask: %d. Must be between 8 and 32", m)
}
bc.ipv4Mask = net.CIDRMask(m, 32)

return nil
}

func (bc *CaptchaProtect) SetIpv6Mask(m int) error {
if m < 8 || m > 128 {
return fmt.Errorf("invalid ipv6 mask: %d. Must be between 8 and 128", m)
}
bc.ipv6Mask = net.CIDRMask(m, 128)

return nil
}

func (bc *CaptchaProtect) isGoodBot(req *http.Request, clientIP string) bool {
if bc.config.ProtectParameters == "true" {
if len(req.URL.Query()) > 0 {
Expand Down Expand Up @@ -528,6 +547,10 @@ func IsIpGoodBot(clientIP string, goodBots []string) bool {
return false
}

func (bc *CaptchaProtect) SetExemptIps(exemptIps []*net.IPNet) {
bc.exemptIps = exemptIps
}

func ParseCIDR(cidr string) (*net.IPNet, error) {
_, ipNet, err := net.ParseCIDR(cidr)
if err != nil {
Expand Down
54 changes: 30 additions & 24 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,15 @@ import (
"net"
"net/http/httptest"
"os"
"strings"
"testing"
)

func init() {
log = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
Level: slog.LevelDebug,
}))
}

func TestIsIpGoodBot(t *testing.T) {
// Save the original functions and restore them at the end.
origLookupAddr := lookupAddrFunc
Expand Down Expand Up @@ -179,20 +184,18 @@ func TestParseIp(t *testing.T) {
wantSubnet: "192.168.1.1",
},
{
name: "IPv6 /64",
ip: "2001:0db8:85a3:0000:0000:8a2e:0370:7334",
ipv6Mask: 64,
wantFull: "2001:0db8:85a3:0000:0000:8a2e:0370:7334",
// for /64, we keep 4 hextets
wantSubnet: strings.Join(strings.Split("2001:0db8:85a3:0000:0000:8a2e:0370:7334", ":")[:4], ":"),
name: "IPv6 /64",
ip: "2001:0db8:85a3:0000:0000:8a2e:0370:7334",
ipv6Mask: 64,
wantFull: "2001:0db8:85a3:0000:0000:8a2e:0370:7334",
wantSubnet: "2001:db8:85a3::",
},
{
name: "IPv6 /48",
ip: "2001:0db8:85a3:0000:0000:8a2e:0370:7334",
ipv6Mask: 48,
wantFull: "2001:0db8:85a3:0000:0000:8a2e:0370:7334",
// for /48, we keep 3 hextets
wantSubnet: strings.Join(strings.Split("2001:0db8:85a3:0000:0000:8a2e:0370:7334", ":")[:3], ":"),
name: "IPv6 /48",
ip: "2001:0db8:85a3:0000:0000:8a2e:0370:7334",
ipv6Mask: 48,
wantFull: "2001:0db8:85a3:0000:0000:8a2e:0370:7334",
wantSubnet: "2001:db8:85a3::",
},
{
name: "Invalid IP returns same string",
Expand All @@ -206,7 +209,13 @@ func TestParseIp(t *testing.T) {

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
gotFull, gotSubnet := ParseIp(tc.ip, tc.ipv4Mask, tc.ipv6Mask)
c := CreateConfig()
bc := &CaptchaProtect{
config: c,
ipv4Mask: net.CIDRMask(tc.ipv4Mask, 32),
ipv6Mask: net.CIDRMask(tc.ipv6Mask, 128),
}
gotFull, gotSubnet := bc.ParseIp(tc.ip)
if gotFull != tc.wantFull {
t.Errorf("ParseIp(%q, %d, %d) got full = %q, want %q", tc.ip, tc.ipv4Mask, tc.ipv6Mask, gotFull, tc.wantFull)
}
Expand Down Expand Up @@ -454,10 +463,6 @@ func TestGetClientIP(t *testing.T) {
expectedIP: "5.5.5.5",
},
}
handler := slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
Level: slog.LevelDebug,
})
log = slog.New(handler)
for _, tc := range tests {

t.Run(tc.name, func(t *testing.T) {
Expand All @@ -471,14 +476,15 @@ func TestGetClientIP(t *testing.T) {
c.IPForwardedHeader = tc.config.IPForwardedHeader
c.IPDepth = tc.config.IPDepth
exemptIps := tc.exemptIps
for _, ip := range c.ExemptIPs {
_, r := ParseIp(ip, 16, 64)
exemptIps = append(exemptIps, parseCIDR(r, t))
}
bc := &CaptchaProtect{
config: c,
exemptIps: exemptIps,
config: c,
ipv4Mask: net.CIDRMask(16, 32),
ipv6Mask: net.CIDRMask(64, 128),
}
for _, ip := range c.ExemptIPs {
exemptIps = append(exemptIps, parseCIDR(ip, t))
}
bc.exemptIps = exemptIps

ip, _ := bc.getClientIP(req)
if ip != tc.expectedIP {
Expand Down

0 comments on commit 7a58dd2

Please sign in to comment.