From ccdd03aa661e3efbfcd8b09a01114077bfb41d69 Mon Sep 17 00:00:00 2001 From: devStorm <59678453+developStorm@users.noreply.github.com> Date: Fri, 1 Mar 2024 20:44:23 -0800 Subject: [PATCH] refactor: AddDefaultPortToDNSServerName the refactored function relies on built-in net package to perform more rigorous validation on input nameserver address --- internal/util/util.go | 33 ++++++++++++++++++++------------- pkg/zdns/lookup.go | 11 +++++++++-- pkg/zdns/zdns.go | 6 +++++- 3 files changed, 34 insertions(+), 16 deletions(-) diff --git a/internal/util/util.go b/internal/util/util.go index 7704c761..9167e3e6 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -16,7 +16,7 @@ package util import ( "fmt" - "regexp" + "net" "strings" "github.com/spf13/cobra" @@ -24,19 +24,28 @@ import ( "github.com/spf13/viper" ) -var rePort *regexp.Regexp -var reV6 *regexp.Regexp - const EnvPrefix = "ZDNS" -func AddDefaultPortToDNSServerName(s string) string { - if !rePort.MatchString(s) { - return s + ":53" - } else if reV6.MatchString(s) { - return "[" + s + "]:53" - } else { - return s +func AddDefaultPortToDNSServerName(inAddr string) (string, error) { + // Try to split host and port to see if the port is already specified. + host, port, err := net.SplitHostPort(inAddr) + if err != nil { + // might mean there's no port specified + host = inAddr + } + + // Validate the host part as an IP address. + ip := net.ParseIP(host) + if ip == nil { + return "", fmt.Errorf("invalid IP address") } + + // If the original input does not have a port, specify port 53 + if port == "" { + port = "53" + } + + return net.JoinHostPort(ip.String(), port), nil } // Reference: https://github.com/carolynvs/stingoftheviper/blob/main/main.go @@ -65,6 +74,4 @@ func GetDefaultResolvers() []string { } func init() { - rePort = regexp.MustCompile(":\\d+$") // string ends with potential port number - reV6 = regexp.MustCompile("^([0-9a-f]*:)") // string starts like valid IPv6 address } diff --git a/pkg/zdns/lookup.go b/pkg/zdns/lookup.go index 4a348b0c..486a057a 100644 --- a/pkg/zdns/lookup.go +++ b/pkg/zdns/lookup.go @@ -77,7 +77,11 @@ func parseNormalInputLine(line string) (string, string) { if len(s) == 1 { return s[0], "" } else { - return s[0], util.AddDefaultPortToDNSServerName(s[1]) + ns, err := util.AddDefaultPortToDNSServerName(s[1]) + if err != nil { + log.Fatal("Unable to parse nameserver: ", err) + } + return s[0], ns } } @@ -124,7 +128,10 @@ func doLookup(g GlobalLookupFactory, gc *GlobalConf, input <-chan interface{}, o rawName, entryMetadata = parseMetadataInputLine(line) res.Metadata = entryMetadata } else if gc.NameServerMode { - nameServer = util.AddDefaultPortToDNSServerName(line) + nameServer, err = util.AddDefaultPortToDNSServerName(line) + if err != nil { + log.Fatal("Unable to parse nameserver: ", err) + } } else { rawName, nameServer = parseNormalInputLine(line) } diff --git a/pkg/zdns/zdns.go b/pkg/zdns/zdns.go index 3a8a960e..94ae8ce1 100644 --- a/pkg/zdns/zdns.go +++ b/pkg/zdns/zdns.go @@ -131,7 +131,11 @@ func Run(gc GlobalConf, flags *pflag.FlagSet, ns = strings.Split(*servers_string, ",") } for i, s := range ns { - ns[i] = util.AddDefaultPortToDNSServerName(s) + nsWithPort, err := util.AddDefaultPortToDNSServerName(s) + if err != nil { + log.Fatal("Unable to parse nameserver: ", err) + } + ns[i] = nsWithPort } gc.NameServers = ns gc.NameServersSpecified = true