diff --git a/cmd/sniproxy/config.defaults.yaml b/cmd/sniproxy/config.defaults.yaml index bb2eb23..5a5cbed 100644 --- a/cmd/sniproxy/config.defaults.yaml +++ b/cmd/sniproxy/config.defaults.yaml @@ -6,6 +6,7 @@ # note that there are 2 underscores between general and BIND_DNS_OVER_UDP general: # Upsteam DNS URI. examples: Upstream DNS URI. examples: udp://1.1.1.1:53, tcp://1.1.1.1:53, tcp-tls://1.1.1.1:853, https://dns.google/dns-query + # NOTE: if you're using SOCKS, avoid using UDP for upstream DNS upstream_dns: udp://8.8.8.8:53 # enable send DNS through socks5 upstream_dns_over_socks5: false @@ -25,8 +26,12 @@ general: tls_key: # HTTP Port to listen on. Should remain 80 in most cases. use :80 to listen on both IPv4 and IPv6 bind_http: "0.0.0.0:80" + # bind additional ports for HTTP. a list of portsor ranges separated by commas. example: "8080,8081-8083". follows the same listen address as bind_http + bind_http_additional: ["8080","8081-8083"] # HTTPS Port to listen on. Should remain 443 in most cases bind_https: "0.0.0.0:443" + # bind additional ports for HTTPS. a list of portsor ranges separated by commas. example: "8443,8444-8446". follows the same listen address as bind_https + bind_https_additional: ["8443","8444-8446"] # Enable prometheus endpoint on IP:PORT. example: 127.0.0.1:8080. Always exposes /metrics and only supports HTTP bind_prometheus: # Interface used for outbound TLS connections. uses OS prefered one if empty diff --git a/cmd/sniproxy/main.go b/cmd/sniproxy/main.go index ef19c8a..8108aa7 100644 --- a/cmd/sniproxy/main.go +++ b/cmd/sniproxy/main.go @@ -126,7 +126,9 @@ func main() { c.TLSCert = generalConfig.String("tls_cert") c.TLSKey = generalConfig.String("tls_key") c.BindHTTP = generalConfig.String("bind_http") + c.BindHTTPAdditional = generalConfig.Strings("bind_http_additional") c.BindHTTPS = generalConfig.String("bind_https") + c.BindHTTPSAdditional = generalConfig.Strings("bind_https_additional") c.Interface = generalConfig.String("interface") c.PreferredVersion = generalConfig.String("preferred_version") @@ -242,8 +244,24 @@ func main() { return } - go sniproxy.RunHTTP(&c, logger.With().Str("service", "http").Logger()) - go sniproxy.RunHTTPS(&c, logger.With().Str("service", "https").Logger()) + // get a list of http and https binds + if err := c.SetBindHTTPListeners(logger); err != nil { + logger.Error().Msgf("error setting up HTTP listeners: %v", err) + return + } + logger.Info().Msgf("HTTP listeners: %v", c.BindHTTPListeners) + if err := c.SetBindHTTPSListeners(logger); err != nil { + logger.Error().Msgf("error setting up HTTPS listeners: %v", err) + return + } + logger.Info().Msgf("HTTPS listeners: %v", c.BindHTTPSListeners) + + for _, addr := range c.BindHTTPListeners { + go sniproxy.RunHTTP(&c, addr, logger.With().Str("service", "http").Str("listener", addr).Logger()) + } + for _, addr := range c.BindHTTPSListeners { + go sniproxy.RunHTTPS(&c, addr, logger.With().Str("service", "https").Str("listener", addr).Logger()) + } go sniproxy.RunDNS(&c, logger.With().Str("service", "dns").Logger()) // wait forever. TODO: add signal handling here diff --git a/go.mod b/go.mod index 311286a..78e4d75 100644 --- a/go.mod +++ b/go.mod @@ -53,7 +53,7 @@ require ( github.com/prometheus/client_model v0.6.1 // indirect github.com/prometheus/common v0.59.1 // indirect github.com/prometheus/procfs v0.15.1 // indirect - github.com/quic-go/qpack v0.5.0 // indirect + github.com/quic-go/qpack v0.5.1 // indirect github.com/quic-go/quic-go v0.46.0 // indirect github.com/redis/go-redis/v9 v9.6.1 // indirect github.com/sirupsen/logrus v1.9.3 // indirect diff --git a/go.sum b/go.sum index 0e6311b..9cc2278 100644 --- a/go.sum +++ b/go.sum @@ -303,8 +303,8 @@ github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4O github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= -github.com/quic-go/qpack v0.5.0 h1:jldbr38Ef/swDfxtvNvvUIYNg5LNm3Oa9W+IZvCm4q0= -github.com/quic-go/qpack v0.5.0/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg= +github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI= +github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg= github.com/quic-go/quic-go v0.46.0 h1:uuwLClEEyk1DNvchH8uCByQVjo3yKL9opKulExNDs7Y= github.com/quic-go/quic-go v0.46.0/go.mod h1:1dLehS7TIR64+vxGR70GDcatWTOtMX2PUtnKsjbTurI= github.com/rcrowley/go-metrics v0.0.0-20190826022208-cac0b30c2563/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= diff --git a/pkg/conf.go b/pkg/conf.go index 265b42e..3369e75 100644 --- a/pkg/conf.go +++ b/pkg/conf.go @@ -4,6 +4,8 @@ import ( "fmt" "net/netip" "net/url" + "strconv" + "strings" "github.com/mosajjal/sniproxy/v2/pkg/acl" "github.com/rcrowley/go-metrics" @@ -13,22 +15,26 @@ import ( ) type Config struct { - PublicIPv4 string `yaml:"public_ipv4"` - PublicIPv6 string `yaml:"public_ipv6"` - UpstreamDNS string `yaml:"upstream_dns"` - UpstreamDNSOverSocks5 bool `yaml:"upstream_dns_over_socks5"` - UpstreamSOCKS5 string `yaml:"upstream_socks5"` - BindDNSOverUDP string `yaml:"bind_dns_over_udp"` - BindDNSOverTCP string `yaml:"bind_dns_over_tcp"` - BindDNSOverTLS string `yaml:"bind_dns_over_tls"` - BindDNSOverQuic string `yaml:"bind_dns_over_quic"` - TLSCert string `yaml:"tls_cert"` - TLSKey string `yaml:"tls_key"` - BindHTTP string `yaml:"bind_http"` - BindHTTPS string `yaml:"bind_https"` - Interface string `yaml:"interface"` - BindPrometheus string `yaml:"bind_prometheus"` - AllowConnToLocal bool `yaml:"allow_conn_to_local"` + PublicIPv4 string `yaml:"public_ipv4"` + PublicIPv6 string `yaml:"public_ipv6"` + UpstreamDNS string `yaml:"upstream_dns"` + UpstreamDNSOverSocks5 bool `yaml:"upstream_dns_over_socks5"` + UpstreamSOCKS5 string `yaml:"upstream_socks5"` + BindDNSOverUDP string `yaml:"bind_dns_over_udp"` + BindDNSOverTCP string `yaml:"bind_dns_over_tcp"` + BindDNSOverTLS string `yaml:"bind_dns_over_tls"` + BindDNSOverQuic string `yaml:"bind_dns_over_quic"` + TLSCert string `yaml:"tls_cert"` + TLSKey string `yaml:"tls_key"` + BindHTTP string `yaml:"bind_http"` + BindHTTPAdditional []string `yaml:"bind_http_additional"` + BindHTTPListeners []string `yaml:"-"` // compiled list of bind_http and bind_http_additional listen addresses + BindHTTPS string `yaml:"bind_https"` + BindHTTPSAdditional []string `yaml:"bind_https_additional"` + BindHTTPSListeners []string `yaml:"-"` // compiled list of bind_https and bind_https_additional listen addresses + Interface string `yaml:"interface"` + BindPrometheus string `yaml:"bind_prometheus"` + AllowConnToLocal bool `yaml:"allow_conn_to_local"` Acl []acl.ACL `yaml:"-"` @@ -87,21 +93,98 @@ func (c *Config) SetDNSClient(logger zerolog.Logger) error { var dnsProxy string var dnsClient *DNSClient // if upstream socks5 is not provided or upstream dns over socks5 is disabled, disable socks5 for dns - if c.UpstreamSOCKS5 == "" || !c.UpstreamDNSOverSocks5 { + if c.UpstreamSOCKS5 != "" && !c.UpstreamDNSOverSocks5 { logger.Debug().Msg("disabling socks5 for dns because either upstream socks5 is not provided or upstream dns over socks5 is disabled") dnsProxy = "" } else { dnsProxy = c.UpstreamSOCKS5 - var err error - dnsClient, err = NewDNSClient(c, c.UpstreamDNS, true, dnsProxy) + } + var err error + dnsClient, err = NewDNSClient(c, c.UpstreamDNS, true, dnsProxy) + if err != nil { + logger.Error().Msgf("error setting up dns client with socks5 proxy, falling back to direct DNS client: %v", err) + dnsClient, err = NewDNSClient(c, c.UpstreamDNS, false, "") if err != nil { - logger.Error().Msgf("error setting up dns client with socks5 proxy, falling back to direct DNS client: %v", err) - dnsClient, err = NewDNSClient(c, c.UpstreamDNS, false, "") + return fmt.Errorf("error setting up dns client: %v", err) + } + } + c.DnsClient = *dnsClient + return nil +} + +// parseRanges parses a range of ports or a single port. It returns a list of ports +func parseRanges(portRange ...string) ([]int, error) { + var ports []int + + for _, portRange := range portRange { + + if strings.Index(portRange, "-") == -1 { + port, err := strconv.Atoi(portRange) + if err != nil { + return nil, fmt.Errorf("error parsing port: %v", err) + } + ports = append(ports, port) + } else { + num1Str := strings.Split(portRange, "-")[0] + num2Str := strings.Split(portRange, "-")[1] + // convert both numbers to integers + + num1, err := strconv.Atoi(num1Str) if err != nil { - return fmt.Errorf("error setting up dns client: %v", err) + return nil, fmt.Errorf("error parsing port range: %v", err) + } + num2, err := strconv.Atoi(num2Str) + if err != nil { + return nil, fmt.Errorf("error parsing port range: %v", err) + } + for i := num1; i <= num2; i++ { + ports = append(ports, i) } } } - c.DnsClient = *dnsClient + return ports, nil +} + +// parseBinders parses a bind address and a list of additional ports or port ranges +func parseBinders(bind string, additional []string) ([]string, error) { + // get the bind address from bind + bindAddPort, err := netip.ParseAddrPort(bind) + if err != nil { + return nil, fmt.Errorf("error parsing bind address: %v", err) + } + bindAddresses := []string{bindAddPort.String()} + + // now all the ranges must be parsed, and each of them converted into a bind address and added to the list + portRange, err := parseRanges(additional...) + if err != nil { + return nil, fmt.Errorf("error parsing bind address range: %v", err) + } + for _, port := range portRange { + bindAddresses = append(bindAddresses, fmt.Sprintf("%s:%d", bindAddPort.Addr(), port)) + } + return bindAddresses, nil +} + +// SetBindHTTPListeners sets up a list of bind addresses for HTTP +// it gets the bind address from bind_http as 0.0.0.0:80 format +// and the additional bind addresses from bind_http_additional as a list of ports or port ranges +// such as 8080, 8081-8083, 8085 +// when this function is called, it will compile the list of bind addresses and store it in BindHTTPListeners +func (c *Config) SetBindHTTPListeners(logger zerolog.Logger) error { + bindAddresses, err := parseBinders(c.BindHTTP, c.BindHTTPAdditional) + if err != nil { + return fmt.Errorf("error parsing bind addresses for HTTP: %v", err) + } + c.BindHTTPListeners = bindAddresses + return nil +} + +// SetBindHTTPSListeners sets up a list of bind addresses for HTTPS +func (c *Config) SetBindHTTPSListeners(logger zerolog.Logger) error { + bindAddresses, err := parseBinders(c.BindHTTPS, c.BindHTTPSAdditional) + if err != nil { + return fmt.Errorf("error parsing bind addresses for HTTPS: %v", err) + } + c.BindHTTPSListeners = bindAddresses return nil } diff --git a/pkg/dns.go b/pkg/dns.go index 51c5cbd..a7c821a 100644 --- a/pkg/dns.go +++ b/pkg/dns.go @@ -9,6 +9,7 @@ import ( "net/url" "strings" "sync" + "time" rdns "github.com/folbricht/routedns" @@ -148,7 +149,7 @@ func (dnsc DNSClient) lookupDomain(domain string, version string) (netip.Addr, e return dnsc.lookupDomain4(domain) case "ipv6only": return dnsc.lookupDomain6(domain) - case "ipv4", "4", "0": + case "ipv4", "4", "0", "": // try with ipv4, if there's any error, try with ipv6 ip, err := dnsc.lookupDomain4(domain) if err != nil { @@ -186,6 +187,7 @@ func (dnsc DNSClient) lookupDomain4(domain string) (netip.Addr, error) { } return netip.IPv4Unspecified(), fmt.Errorf("[DNS] Unknown type %s", dns.TypeToString[rAddrDNS[0].Header().Rrtype]) } + func (dnsc DNSClient) lookupDomain6(domain string) (netip.Addr, error) { if !strings.HasSuffix(domain, ".") { domain = domain + "." @@ -323,12 +325,13 @@ func getDialerFromProxyURL(proxyURL *url.URL) (*rdns.Dialer, error) { dialer = &net.Dialer{} if proxyURL != nil && proxyURL.Host != "" { // create a net dialer with proxy - var auth *proxy.Auth + auth := new(proxy.Auth) if proxyURL.User != nil { - auth = new(proxy.Auth) auth.User = proxyURL.User.Username() if p, ok := proxyURL.User.Password(); ok { auth.Password = p + } else { + auth.Password = "" } } c, err := socks5.NewClient(proxyURL.Host, auth.User, auth.Password, 0, 5) // 0 and 5 are borrowed from routedns pr @@ -378,15 +381,16 @@ func NewDNSClient(C *Config, uri string, skipVerify bool, proxy string) (*DNSCli var ldarr net.IP if parsedURL.Scheme == "udp6" { - ldarr = C.pickSrcAddr(6) + ldarr = C.pickSrcAddr("ipv6only") } else { - ldarr = C.pickSrcAddr(4) + ldarr = C.pickSrcAddr("ipv4only") } opt := rdns.DNSClientOptions{ - LocalAddr: ldarr, - UDPSize: 1300, - Dialer: *dialer, + LocalAddr: ldarr, + UDPSize: 1300, + Dialer: *dialer, + QueryTimeout: 10 * time.Second, //TODO: make this configurable } id, err := rdns.NewDNSClient("id", Address, "udp", opt) if err != nil { @@ -403,9 +407,9 @@ func NewDNSClient(C *Config, uri string, skipVerify bool, proxy string) (*DNSCli var ldarr net.IP if parsedURL.Scheme == "tcp6" { - ldarr = C.pickSrcAddr(6) + ldarr = C.pickSrcAddr("ipv6only") } else { - ldarr = C.pickSrcAddr(4) + ldarr = C.pickSrcAddr("ipv4only") } Address := rdns.AddressWithDefault(host, port) @@ -427,10 +431,10 @@ func NewDNSClient(C *Config, uri string, skipVerify bool, proxy string) (*DNSCli var ldarr net.IP bootstrapAddr := "1.1.1.1" if parsedURL.Scheme == "tls6" || parsedURL.Scheme == "tcp-tls6" { - ldarr = C.pickSrcAddr(6) + ldarr = C.pickSrcAddr("ipv6only") bootstrapAddr = "2606:4700:4700::1111" } else { - ldarr = C.pickSrcAddr(4) + ldarr = C.pickSrcAddr("ipv4only") } opt := rdns.DoTClientOptions{ @@ -456,7 +460,7 @@ func NewDNSClient(C *Config, uri string, skipVerify bool, proxy string) (*DNSCli TLSConfig: tlsConfig, BootstrapAddr: "1.1.1.1", //TODO: make this configurable Transport: transport, - LocalAddr: C.pickSrcAddr(4), //TODO:support IPv6 + LocalAddr: C.pickSrcAddr("ipv4only"), //TODO:support IPv6 Dialer: *dialer, } id, err := rdns.NewDoHClient("id", parsedURL.String(), opt) @@ -473,7 +477,7 @@ func NewDNSClient(C *Config, uri string, skipVerify bool, proxy string) (*DNSCli opt := rdns.DoQClientOptions{ TLSConfig: tlsConfig, - LocalAddr: C.pickSrcAddr(4), //TODO:support IPv6 + LocalAddr: C.pickSrcAddr("ipv4only"), //TODO:support IPv6 // Dialer: *dialer, // BUG: not yet supported } id, err := rdns.NewDoQClient("id", parsedURL.Host, opt) diff --git a/pkg/dns_test.go b/pkg/dns_test.go index d25f543..d80b8ab 100644 --- a/pkg/dns_test.go +++ b/pkg/dns_test.go @@ -17,11 +17,11 @@ func TestDNSClient_lookupDomain4(t *testing.T) { client *DNSClient name string domain string - want net.IP + want []net.IP wantErr bool }{ - {client: dnsc, name: "test1", domain: "ident.me", want: net.IPv4(49, 12, 234, 183), wantErr: false}, - {client: dnsc, name: "test1", domain: "ifconfig.me", want: net.IPv4(34, 117, 118, 44), wantErr: false}, + {client: dnsc, name: "test1", domain: "ident.me", want: []net.IP{net.IPv4(49, 12, 234, 183)}, wantErr: false}, + {client: dnsc, name: "test2", domain: "one.one.one.one", want: []net.IP{net.IPv4(1, 1, 1, 1), net.IPv4(1, 0, 0, 1)}, wantErr: false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -31,8 +31,15 @@ func TestDNSClient_lookupDomain4(t *testing.T) { t.Errorf("DNSClient.lookupDomain4() error = %v, wantErr %v", err, tt.wantErr) return } - if !got.Equal(tt.want) { - + // check if the returned IP is in the list of expected IPs + found := false + for _, w := range tt.want { + if got.Equal(w) { + found = true + break + } + } + if !found { t.Errorf("DNSClient.lookupDomain4() = %v, want %v", got, tt.want) } }) diff --git a/pkg/httpproxy.go b/pkg/httpproxy.go index 41a1dfc..c2cdb56 100644 --- a/pkg/httpproxy.go +++ b/pkg/httpproxy.go @@ -38,20 +38,22 @@ var passthruResponseHeaderKeys = [...]string{ "Vary", } -func RunHTTP(c *Config, l zerolog.Logger) { +// RunHTTP starts the HTTP server on the configured bind. bind format is 0.0.0.0:80 or similar +func RunHTTP(c *Config, bind string, l zerolog.Logger) { httplog = l - handler := http.DefaultServeMux + handler := http.NewServeMux() handler.HandleFunc("/", handle80(c)) s := &http.Server{ - Addr: c.BindHTTP, + Addr: bind, Handler: handler, ReadTimeout: 10 * time.Second, WriteTimeout: 10 * time.Second, MaxHeaderBytes: 1 << 20, } + httplog.Info().Str("bind", bind).Msg("starting http server") if err := s.ListenAndServe(); err != nil { httplog.Error().Msg(err.Error()) panic(-1) diff --git a/pkg/https.go b/pkg/https.go index f7b6f2e..9316245 100644 --- a/pkg/https.go +++ b/pkg/https.go @@ -165,11 +165,12 @@ func getPortFromConn(conn net.Conn) int { return portnum } -func RunHTTPS(c *Config, log zerolog.Logger) { - if l, err := net.Listen("tcp", c.BindHTTPS); err != nil { +func RunHTTPS(c *Config, bind string, log zerolog.Logger) { + if l, err := net.Listen("tcp", bind); err != nil { log.Error().Msg(err.Error()) panic(-1) } else { + log.Info().Msgf("listening https on %s", bind) defer l.Close() for { if con, err := l.Accept(); err != nil { diff --git a/pkg/publicip.go b/pkg/publicip.go index 791b123..4cc3001 100644 --- a/pkg/publicip.go +++ b/pkg/publicip.go @@ -26,7 +26,7 @@ func GetPublicIPv4() (string, error) { } externalIP := "" // trying to get the public IP from multiple sources to see if they match. - resp, err := http.Get("https://myexternalip.com/raw") + resp, err := http.Get("https://4.ident.me") if err == nil { defer resp.Body.Close() body, err := io.ReadAll(resp.Body)