Skip to content

Commit

Permalink
Merge pull request #293 from safing/fix/patch-set-3
Browse files Browse the repository at this point in the history
DNS and other fixes & improvements
  • Loading branch information
dhaavi authored Apr 19, 2021
2 parents 50d10ff + 81fb67b commit 06eee68
Show file tree
Hide file tree
Showing 20 changed files with 297 additions and 221 deletions.
124 changes: 66 additions & 58 deletions firewall/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"net"
"strings"
"time"

"github.com/miekg/dns"
"github.com/safing/portbase/database"
Expand All @@ -16,9 +15,19 @@ import (
"github.com/safing/portmaster/resolver"
)

func filterDNSSection(entries []dns.RR, p *profile.LayeredProfile, resolverScope netutils.IPScope, sysResolver bool) ([]dns.RR, []string, int, string) {
func filterDNSSection(
ctx context.Context,
entries []dns.RR,
p *profile.LayeredProfile,
resolverScope netutils.IPScope,
sysResolver bool,
) ([]dns.RR, []string, int, string) {

// Will be filled 1:1 most of the time.
goodEntries := make([]dns.RR, 0, len(entries))
filteredRecords := make([]string, 0, len(entries))

// Will stay empty most of the time.
var filteredRecords []string

// keeps track of the number of valid and allowed
// A and AAAA records.
Expand All @@ -44,13 +53,16 @@ func filterDNSSection(entries []dns.RR, p *profile.LayeredProfile, resolverScope
switch {
case ipScope.IsLocalhost():
// No DNS should return localhost addresses
filteredRecords = append(filteredRecords, rr.String())
filteredRecords = append(filteredRecords, formatRR(rr))
interveningOptionKey = profile.CfgOptionRemoveOutOfScopeDNSKey
log.Tracer(ctx).Tracef("filter: RR violates resolver scope: %s", formatRR(rr))
continue

case resolverScope.IsGlobal() && ipScope.IsLAN() && !sysResolver:
// No global DNS should return LAN addresses
filteredRecords = append(filteredRecords, rr.String())
filteredRecords = append(filteredRecords, formatRR(rr))
interveningOptionKey = profile.CfgOptionRemoveOutOfScopeDNSKey
log.Tracer(ctx).Tracef("filter: RR violates resolver scope: %s", formatRR(rr))
continue
}
}
Expand All @@ -59,16 +71,21 @@ func filterDNSSection(entries []dns.RR, p *profile.LayeredProfile, resolverScope
// filter by flags
switch {
case p.BlockScopeInternet() && ipScope.IsGlobal():
filteredRecords = append(filteredRecords, rr.String())
filteredRecords = append(filteredRecords, formatRR(rr))
interveningOptionKey = profile.CfgOptionBlockScopeInternetKey
log.Tracer(ctx).Tracef("filter: RR is in blocked scope Internet: %s", formatRR(rr))
continue

case p.BlockScopeLAN() && ipScope.IsLAN():
filteredRecords = append(filteredRecords, rr.String())
filteredRecords = append(filteredRecords, formatRR(rr))
interveningOptionKey = profile.CfgOptionBlockScopeLANKey
log.Tracer(ctx).Tracef("filter: RR is in blocked scope LAN: %s", formatRR(rr))
continue

case p.BlockScopeLocal() && ipScope.IsLocalhost():
filteredRecords = append(filteredRecords, rr.String())
filteredRecords = append(filteredRecords, formatRR(rr))
interveningOptionKey = profile.CfgOptionBlockScopeLocalKey
log.Tracer(ctx).Tracef("filter: RR is in blocked scope Localhost: %s", formatRR(rr))
continue
}

Expand All @@ -83,9 +100,13 @@ func filterDNSSection(entries []dns.RR, p *profile.LayeredProfile, resolverScope
return goodEntries, filteredRecords, allowedAddressRecords, interveningOptionKey
}

func filterDNSResponse(conn *network.Connection, rrCache *resolver.RRCache, sysResolver bool) *resolver.RRCache {
p := conn.Process().Profile()

func filterDNSResponse(
ctx context.Context,
conn *network.Connection,
p *profile.LayeredProfile,
rrCache *resolver.RRCache,
sysResolver bool,
) *resolver.RRCache {
// do not modify own queries
if conn.Process().Pid == ownPID {
return rrCache
Expand All @@ -96,20 +117,20 @@ func filterDNSResponse(conn *network.Connection, rrCache *resolver.RRCache, sysR
return rrCache
}

// duplicate entry
rrCache = rrCache.ShallowCopy()
rrCache.FilteredEntries = make([]string, 0)

var filteredRecords []string
var validIPs int
var interveningOptionKey string

rrCache.Answer, filteredRecords, validIPs, interveningOptionKey = filterDNSSection(rrCache.Answer, p, rrCache.Resolver.IPScope, sysResolver)
rrCache.FilteredEntries = append(rrCache.FilteredEntries, filteredRecords...)
rrCache.Answer, filteredRecords, validIPs, interveningOptionKey = filterDNSSection(ctx, rrCache.Answer, p, rrCache.Resolver.IPScope, sysResolver)
if len(filteredRecords) > 0 {
rrCache.FilteredEntries = append(rrCache.FilteredEntries, filteredRecords...)
}

// we don't count the valid IPs in the extra section
rrCache.Extra, filteredRecords, _, _ = filterDNSSection(rrCache.Extra, p, rrCache.Resolver.IPScope, sysResolver)
rrCache.FilteredEntries = append(rrCache.FilteredEntries, filteredRecords...)
// Don't count the valid IPs in the extra section.
rrCache.Extra, filteredRecords, _, _ = filterDNSSection(ctx, rrCache.Extra, p, rrCache.Resolver.IPScope, sysResolver)
if len(filteredRecords) > 0 {
rrCache.FilteredEntries = append(rrCache.FilteredEntries, filteredRecords...)
}

if len(rrCache.FilteredEntries) > 0 {
rrCache.Filtered = true
Expand All @@ -127,34 +148,8 @@ func filterDNSResponse(conn *network.Connection, rrCache *resolver.RRCache, sysR
conn.Block("DNS response only contained to-be-blocked IPs", interveningOptionKey)
}

// If all entries are filtered, this could mean that these are broken/bogus resource records.
if rrCache.Expired() {
// If the entry is expired, force delete it.
err := resolver.ResetCachedRecord(rrCache.Domain, rrCache.Question.String())
if err != nil && err != database.ErrNotFound {
log.Warningf(
"filter: failed to delete fully filtered name cache for %s: %s",
rrCache.ID(),
err,
)
}
} else if rrCache.Expires > time.Now().Add(10*time.Second).Unix() {
// Set a low TTL of 10 seconds if TTL is higher than that.
rrCache.Expires = time.Now().Add(10 * time.Second).Unix()
err := rrCache.Save()
if err != nil {
log.Debugf(
"filter: failed to set shorter TTL on fully filtered name cache for %s: %s",
rrCache.ID(),
err,
)
}
}

return nil
return rrCache
}

log.Infof("filter: filtered DNS replies for %s: %s", conn, strings.Join(rrCache.FilteredEntries, ", "))
}

return rrCache
Expand All @@ -168,43 +163,51 @@ func FilterResolvedDNS(
q *resolver.Query,
rrCache *resolver.RRCache,
) *resolver.RRCache {
// Check if we have a process and profile.
layeredProfile := conn.Process().Profile()
if layeredProfile == nil {
log.Tracer(ctx).Warning("unknown process or profile")
return nil
}

// special grant for connectivity domains
if checkConnectivityDomain(ctx, conn, nil) {
if checkConnectivityDomain(ctx, conn, layeredProfile, nil) {
// returns true if check triggered
return rrCache
}

// Only filter criticial things if request comes from the system resolver.
sysResolver := conn.Process().IsSystemResolver()

updatedRR := filterDNSResponse(conn, rrCache, sysResolver)
if updatedRR == nil {
return nil
// Filter dns records and return if the query is blocked.
rrCache = filterDNSResponse(ctx, conn, layeredProfile, rrCache, sysResolver)
if conn.Verdict == network.VerdictBlock {
return rrCache
}

if !sysResolver && mayBlockCNAMEs(ctx, conn) {
return nil
// Block by CNAMEs.
if !sysResolver {
mayBlockCNAMEs(ctx, conn, layeredProfile)
}

return updatedRR
return rrCache
}

func mayBlockCNAMEs(ctx context.Context, conn *network.Connection) bool {
func mayBlockCNAMEs(ctx context.Context, conn *network.Connection, p *profile.LayeredProfile) bool {
// if we have CNAMEs and the profile is configured to filter them
// we need to re-check the lists and endpoints here
if conn.Process().Profile().FilterCNAMEs() {
if p.FilterCNAMEs() {
conn.Entity.ResetLists()
conn.Entity.EnableCNAMECheck(ctx, true)

result, reason := conn.Process().Profile().MatchEndpoint(ctx, conn.Entity)
result, reason := p.MatchEndpoint(ctx, conn.Entity)
if result == endpoints.Denied {
conn.BlockWithContext(reason.String(), profile.CfgOptionFilterCNAMEKey, reason.Context())
return true
}

if result == endpoints.NoMatch {
result, reason = conn.Process().Profile().MatchFilterLists(ctx, conn.Entity)
result, reason = p.MatchFilterLists(ctx, conn.Entity)
if result == endpoints.Denied {
conn.BlockWithContext(reason.String(), profile.CfgOptionFilterCNAMEKey, reason.Context())
return true
Expand Down Expand Up @@ -304,3 +307,8 @@ func UpdateIPsAndCNAMEs(q *resolver.Query, rrCache *resolver.RRCache, conn *netw
}
}
}

// formatRR is a friendlier alternative to miekg/dns.RR.String().
func formatRR(rr dns.RR) string {
return strings.ReplaceAll(rr.String(), "\t", " ")
}
19 changes: 9 additions & 10 deletions firewall/interception.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,15 +144,14 @@ func getConnection(pkt packet.Packet) (*network.Connection, error) {

// Transform and log result.
conn := newConn.(*network.Connection)
switch {
case created && shared:
log.Tracer(pkt.Ctx()).Tracef("filter: created new connection %s (shared)", conn.ID)
case created:
log.Tracer(pkt.Ctx()).Tracef("filter: created new connection %s", conn.ID)
case shared:
log.Tracer(pkt.Ctx()).Tracef("filter: assigned connection %s (shared)", conn.ID)
default:
log.Tracer(pkt.Ctx()).Tracef("filter: assigned connection %s", conn.ID)
sharedIndicator := ""
if shared {
sharedIndicator = " (shared)"
}
if created {
log.Tracer(pkt.Ctx()).Tracef("filter: created new connection %s%s", conn.ID, sharedIndicator)
} else {
log.Tracer(pkt.Ctx()).Tracef("filter: assigned connection %s%s", conn.ID, sharedIndicator)
}

return conn, nil
Expand Down Expand Up @@ -307,7 +306,7 @@ func initialHandler(conn *network.Connection, pkt packet.Packet) {
log.Tracer(pkt.Ctx()).Trace("filter: handing over to connection-based handler")

// Check for pre-authenticated port.
if localPortIsPreAuthenticated(conn.Entity.Protocol, conn.LocalPort) {
if !conn.Inbound && localPortIsPreAuthenticated(conn.Entity.Protocol, conn.LocalPort) {
// Approve connection.
conn.Accept("connection by Portmaster", noReasonOptionKey)
conn.Internal = true
Expand Down
2 changes: 1 addition & 1 deletion firewall/interception/nfq/nfq.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ func (q *Queue) packetHandler(ctx context.Context) func(nfqueue.Attribute) int {

if attrs.Payload == nil {
// There is not payload.
log.Warningf("nfqueue: packet #%s has no payload", pkt.pktID)
log.Warningf("nfqueue: packet #%d has no payload", pkt.pktID)
return 0
}

Expand Down
Loading

0 comments on commit 06eee68

Please sign in to comment.