Skip to content

Commit

Permalink
Add support for limiting to only returning A / AAAA records
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-sullivan committed Feb 4, 2025
1 parent 89876bd commit 6693cbf
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 66 deletions.
54 changes: 27 additions & 27 deletions responder.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,12 +147,8 @@ func (r *responder) announceAtInterface(service *Service, iface *net.Interface)
answer = append(answer, SRV(*service))
answer = append(answer, PTR(*service))
answer = append(answer, TXT(*service))
for _, a := range A(*service, iface) {
answer = append(answer, a)
}
for _, aaaa := range AAAA(*service, iface) {
answer = append(answer, aaaa)
}
answer = append(answer, aOrAaaaFilter(service, iface)...)

msg := new(dns.Msg)
msg.Answer = answer
msg.Response = true
Expand Down Expand Up @@ -390,13 +386,7 @@ func (r *responder) handleQuestion(q dns.Question, req *Request, srv Service) *d

extra := []dns.RR{SRV(srv), TXT(srv)}

for _, a := range A(srv, req.iface) {
extra = append(extra, a)
}

for _, aaaa := range AAAA(srv, req.iface) {
extra = append(extra, aaaa)
}
extra = append(extra, aOrAaaaFilter(&srv, req.iface)...)

if nsec := NSEC(ptr, srv, req.iface); nsec != nil {
extra = append(extra, nsec)
Expand All @@ -414,13 +404,7 @@ func (r *responder) handleQuestion(q dns.Question, req *Request, srv Service) *d

var extra []dns.RR

for _, a := range A(srv, req.iface) {
extra = append(extra, a)
}

for _, aaaa := range AAAA(srv, req.iface) {
extra = append(extra, aaaa)
}
extra = append(extra, aOrAaaaFilter(&srv, req.iface)...)

if nsec := NSEC(SRV(srv), srv, req.iface); nsec != nil {
extra = append(extra, nsec)
Expand All @@ -436,13 +420,7 @@ func (r *responder) handleQuestion(q dns.Question, req *Request, srv Service) *d
case strings.ToLower(srv.Hostname()):
var answer []dns.RR

for _, a := range A(srv, req.iface) {
answer = append(answer, a)
}

for _, aaaa := range AAAA(srv, req.iface) {
answer = append(answer, aaaa)
}
answer = append(answer, aOrAaaaFilter(&srv, req.iface)...)

resp.Answer = answer

Expand Down Expand Up @@ -520,3 +498,25 @@ func containsConflictingAnswers(req *Request, handle *serviceHandle) bool {

return false
}

func aOrAaaaFilter(service *Service, iface *net.Interface) []dns.RR {
var result []dns.RR
switch service.AdvertiseIPType {
case IPv4:
for _, a := range A(*service, iface) {
result = append(result, a)
}
case IPv6:
for _, aaaa := range AAAA(*service, iface) {
result = append(result, aaaa)
}
default:
for _, a := range A(*service, iface) {
result = append(result, a)
}
for _, aaaa := range AAAA(*service, iface) {
result = append(result, aaaa)
}
}
return result
}
93 changes: 92 additions & 1 deletion responder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package dnssd

import (
"context"
"github.com/brutella/dnssd/log"
"github.com/miekg/dns"
"net"
"testing"
Expand Down Expand Up @@ -49,7 +50,7 @@ func TestRegisterServiceWithExplicitIP(t *testing.T) {
t.Fatal(err)
}
sv.ifaceIPs = map[string][]net.IP{
"lo0": []net.IP{net.IP{192, 168, 0, 123}},
"lo0": {net.IP{192, 168, 0, 123}},
}

conn := newTestConn()
Expand Down Expand Up @@ -101,3 +102,93 @@ func TestRegisterServiceWithExplicitIP(t *testing.T) {
r.Respond(ctx)
})
}

type expectedIP struct {
advType IPType
expected []net.IP
}

func TestRegisterServiceWithSpecifiedAdvertisedIP(t *testing.T) {
log.Debug.Enable()

v4 := net.IP{192, 168, 0, 123}
v6 := net.ParseIP("fe80::1")

var expectedIPs = map[string]expectedIP{
"v4 only": {IPv4, []net.IP{v4}},
"v6 only": {IPv6, []net.IP{v6}},
"both / unspecified": {IPType(0), []net.IP{v4, v6}},
}

for name, expected := range expectedIPs {
t.Run(name, func(t *testing.T) {
cfg := Config{
Host: "Computer",
Name: "Test",
Type: "_asdf._tcp",
Domain: "local",
Port: 12345,
Ifaces: []string{"lo0"},
AdvertiseIPType: expected.advType,
}
sv, err := NewService(cfg)
if err != nil {
t.Fatal(err)
}
sv.ifaceIPs = map[string][]net.IP{
"lo0": {v4, v6},
}

conn := newTestConn()
otherConn := newTestConn()
conn.in = otherConn.out
conn.out = otherConn.in

ctx, cancel := context.WithCancel(context.Background())
t.Run("resolver", func(t *testing.T) {
t.Parallel()

lookupCtx, lookupCancel := context.WithTimeout(ctx, 5*time.Second)

defer lookupCancel()
defer cancel()

srv, err := lookupInstance(lookupCtx, "Test._asdf._tcp.local.", otherConn)
if err != nil {
t.Fatal(err)
}

if is, want := srv.Name, "Test"; is != want {
t.Fatalf("%v != %v", is, want)
}

if is, want := srv.Type, "_asdf._tcp"; is != want {
t.Fatalf("%v != %v", is, want)
}

if is, want := srv.Host, "Computer"; is != want {
t.Fatalf("%v != %v", is, want)
}

ips := srv.IPsAtInterface(&net.Interface{Name: "lo0"})
if is, want := len(ips), len(expected.expected); is != want {
t.Fatalf("%v != %v", is, want)
}

for i, ip := range ips { // this should always be the same order as a records are processed before aaaa records
if is, want := ip, expected.expected[i]; !is.Equal(want) {
t.Fatalf("%v != %v", is, want)
}
}
})

t.Run("responder", func(t *testing.T) {
t.Parallel()

r := newResponder(conn)
r.addManaged(sv) // don't probe
r.Respond(ctx)
})
})
}
}
91 changes: 53 additions & 38 deletions service.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package dnssd

import (
"bytes"

"github.com/brutella/dnssd/log"

"fmt"
Expand Down Expand Up @@ -39,18 +38,23 @@ type Config struct {

// Interfaces at which the service should be registered
Ifaces []string

// The addresses for the interface which should be used (A / AAAA / Both)
// If empty, all addresses are used.
AdvertiseIPType IPType
}

func (c Config) Copy() Config {
return Config{
Name: c.Name,
Type: c.Type,
Domain: c.Domain,
Host: c.Host,
Text: c.Text,
IPs: c.IPs,
Port: c.Port,
Ifaces: c.Ifaces,
Name: c.Name,
Type: c.Type,
Domain: c.Domain,
Host: c.Host,
Text: c.Text,
IPs: c.IPs,
Port: c.Port,
Ifaces: c.Ifaces,
AdvertiseIPType: c.AdvertiseIPType,
}
}

Expand Down Expand Up @@ -97,17 +101,26 @@ func validHostname(host string) string {
return result
}

type IPType int

const (
Both = IPType(0)
IPv4 = IPType(4)
IPv6 = IPType(6)
)

// Service represents a DNS-SD service instance
type Service struct {
Name string
Type string
Domain string
Host string
Text map[string]string
TTL time.Duration // Original time to live
Port int
IPs []net.IP
Ifaces []string
Name string
Type string
Domain string
Host string
Text map[string]string
TTL time.Duration // Original time to live
Port int
IPs []net.IP
Ifaces []string
AdvertiseIPType IPType

// stores ips by interface name for caching purposes
ifaceIPs map[string][]net.IP
Expand Down Expand Up @@ -162,15 +175,16 @@ func NewService(cfg Config) (s Service, err error) {
}

return Service{
Name: trimServiceNameSuffixRight(name),
Type: typ,
Domain: domain,
Host: validHostname(host),
Text: text,
Port: port,
IPs: ips,
Ifaces: ifaces,
ifaceIPs: map[string][]net.IP{},
Name: trimServiceNameSuffixRight(name),
Type: typ,
Domain: domain,
Host: validHostname(host),
Text: text,
Port: port,
IPs: ips,
AdvertiseIPType: cfg.AdvertiseIPType,
Ifaces: ifaces,
ifaceIPs: map[string][]net.IP{},
}, nil
}

Expand Down Expand Up @@ -256,17 +270,18 @@ func (s *Service) HasIPOnAnyInterface(ip net.IP) bool {
// Copy returns a copy of the service.
func (s Service) Copy() *Service {
return &Service{
Name: s.Name,
Type: s.Type,
Domain: s.Domain,
Host: s.Host,
Text: s.Text,
TTL: s.TTL,
IPs: s.IPs,
Port: s.Port,
Ifaces: s.Ifaces,
ifaceIPs: s.ifaceIPs,
expiration: s.expiration,
Name: s.Name,
Type: s.Type,
Domain: s.Domain,
Host: s.Host,
Text: s.Text,
TTL: s.TTL,
IPs: s.IPs,
Port: s.Port,
AdvertiseIPType: s.AdvertiseIPType,
Ifaces: s.Ifaces,
ifaceIPs: s.ifaceIPs,
expiration: s.expiration,
}
}

Expand Down

0 comments on commit 6693cbf

Please sign in to comment.