Skip to content

Commit

Permalink
[client] Enable userspace forwarder conditionally (#3309)
Browse files Browse the repository at this point in the history
* Enable userspace forwarder conditionally

* Move disable/enable logic
  • Loading branch information
lixmal authored Feb 12, 2025
1 parent 18f84f0 commit b41de7f
Show file tree
Hide file tree
Showing 7 changed files with 140 additions and 49 deletions.
8 changes: 8 additions & 0 deletions client/firewall/iptables/manager_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,14 @@ func (m *Manager) SetLogLevel(log.Level) {
// not supported
}

func (m *Manager) EnableRouting() error {
return nil
}

func (m *Manager) DisableRouting() error {
return nil
}

func getConntrackEstablished() []string {
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
}
4 changes: 4 additions & 0 deletions client/firewall/manager/firewall.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ type Manager interface {
Flush() error

SetLogLevel(log.Level)

EnableRouting() error

DisableRouting() error
}

func GenKey(format string, pair RouterPair) string {
Expand Down
8 changes: 8 additions & 0 deletions client/firewall/nftables/manager_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,14 @@ func (m *Manager) SetLogLevel(log.Level) {
// not supported
}

func (m *Manager) EnableRouting() error {
return nil
}

func (m *Manager) DisableRouting() error {
return nil
}

// Flush rule/chain/set operations from the buffer
//
// Method also get all rules after flush and refreshes handle values in the rulesets
Expand Down
144 changes: 104 additions & 40 deletions client/firewall/uspfilter/uspfilter.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ type Manager struct {

mutex sync.RWMutex

// indicates whether server routes are disabled
disableServerRoutes bool
// indicates whether we forward packets not destined for ourselves
routingEnabled bool
// indicates whether we leave forwarding and filtering to the native firewall
Expand Down Expand Up @@ -125,16 +127,28 @@ func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.
return mgr, nil
}

func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) {
disableConntrack, err := strconv.ParseBool(os.Getenv(EnvDisableConntrack))
if err != nil {
log.Warnf("failed to parse %s: %v", EnvDisableConntrack, err)
func parseCreateEnv() (bool, bool) {
var disableConntrack, enableLocalForwarding bool
var err error
if val := os.Getenv(EnvDisableConntrack); val != "" {
disableConntrack, err = strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", EnvDisableConntrack, err)
}
}
enableLocalForwarding, err := strconv.ParseBool(os.Getenv(EnvEnableNetstackLocalForwarding))
if err != nil {
log.Warnf("failed to parse %s: %v", EnvEnableNetstackLocalForwarding, err)
if val := os.Getenv(EnvEnableNetstackLocalForwarding); val != "" {
enableLocalForwarding, err = strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", EnvEnableNetstackLocalForwarding, err)
}
}

return disableConntrack, enableLocalForwarding
}

func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) {
disableConntrack, enableLocalForwarding := parseCreateEnv()

m := &Manager{
decoders: sync.Pool{
New: func() any {
Expand All @@ -149,15 +163,16 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
return d
},
},
nativeFirewall: nativeFirewall,
outgoingRules: make(map[string]RuleSet),
incomingRules: make(map[string]RuleSet),
wgIface: iface,
localipmanager: newLocalIPManager(),
routingEnabled: false,
stateful: !disableConntrack,
logger: nblog.NewFromLogrus(log.StandardLogger()),
netstack: netstack.IsEnabled(),
nativeFirewall: nativeFirewall,
outgoingRules: make(map[string]RuleSet),
incomingRules: make(map[string]RuleSet),
wgIface: iface,
localipmanager: newLocalIPManager(),
disableServerRoutes: disableServerRoutes,
routingEnabled: false,
stateful: !disableConntrack,
logger: nblog.NewFromLogrus(log.StandardLogger()),
netstack: netstack.IsEnabled(),
// default true for non-netstack, for netstack only if explicitly enabled
localForwarding: !netstack.IsEnabled() || enableLocalForwarding,
}
Expand All @@ -166,7 +181,6 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
return nil, fmt.Errorf("update local IPs: %w", err)
}

// Only initialize trackers if stateful mode is enabled
if disableConntrack {
log.Info("conntrack is disabled")
} else {
Expand All @@ -175,7 +189,12 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
}

m.determineRouting(iface, disableServerRoutes)
// netstack needs the forwarder for local traffic
if m.netstack && m.localForwarding {
if err := m.initForwarder(); err != nil {
log.Errorf("failed to initialize forwarder: %v", err)
}
}

if err := m.blockInvalidRouted(iface); err != nil {
log.Errorf("failed to block invalid routed traffic: %v", err)
Expand Down Expand Up @@ -213,17 +232,29 @@ func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error {
return nil
}

func (m *Manager) determineRouting(iface common.IFaceMapper, disableServerRoutes bool) {
disableUspRouting, _ := strconv.ParseBool(os.Getenv(EnvDisableUserspaceRouting))
forceUserspaceRouter, _ := strconv.ParseBool(os.Getenv(EnvForceUserspaceRouter))
func (m *Manager) determineRouting() error {
var disableUspRouting, forceUserspaceRouter bool
var err error
if val := os.Getenv(EnvDisableUserspaceRouting); val != "" {
disableUspRouting, err = strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", EnvDisableUserspaceRouting, err)
}
}
if val := os.Getenv(EnvForceUserspaceRouter); val != "" {
forceUserspaceRouter, err = strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", EnvForceUserspaceRouter, err)
}
}

switch {
case disableUspRouting:
m.routingEnabled = false
m.nativeRouter = false
log.Info("userspace routing is disabled")

case disableServerRoutes:
case m.disableServerRoutes:
// if server routes are disabled we will let packets pass to the native stack
m.routingEnabled = true
m.nativeRouter = true
Expand Down Expand Up @@ -252,40 +283,45 @@ func (m *Manager) determineRouting(iface common.IFaceMapper, disableServerRoutes
log.Info("userspace routing enabled by default")
}

// netstack needs the forwarder for local traffic
if m.netstack && m.localForwarding ||
m.routingEnabled && !m.nativeRouter {

m.initForwarder(iface)
if m.routingEnabled && !m.nativeRouter {
return m.initForwarder()
}

return nil
}

// initForwarder initializes the forwarder, it disables routing on errors
func (m *Manager) initForwarder(iface common.IFaceMapper) {
func (m *Manager) initForwarder() error {
if m.forwarder != nil {
return nil
}

// Only supported in userspace mode as we need to inject packets back into wireguard directly
intf := iface.GetWGDevice()
intf := m.wgIface.GetWGDevice()
if intf == nil {
log.Info("forwarding not supported")
m.routingEnabled = false
return
return errors.New("forwarding not supported")
}

forwarder, err := forwarder.New(iface, m.logger, m.netstack)
forwarder, err := forwarder.New(m.wgIface, m.logger, m.netstack)
if err != nil {
log.Errorf("failed to create forwarder: %v", err)
m.routingEnabled = false
return
return fmt.Errorf("create forwarder: %w", err)
}

m.forwarder = forwarder

log.Debug("forwarder initialized")

return nil
}

func (m *Manager) Init(*statemanager.Manager) error {
return nil
}

func (m *Manager) IsServerRouteSupported() bool {
return m.nativeFirewall != nil || m.routingEnabled && m.forwarder != nil
return true
}

func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
Expand Down Expand Up @@ -586,7 +622,6 @@ func (m *Manager) dropFilter(packetData []byte) bool {
defer m.decoders.Put(d)

if !m.isValidPacket(d, packetData) {
m.logger.Trace("Invalid packet structure")
return true
}

Expand Down Expand Up @@ -658,11 +693,9 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetDat
return false
}

// Get protocol and ports for route ACL check
proto := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d)

// Check route ACLs
if !m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort) {
m.logger.Trace("Dropping routed packet (ACL denied): src=%s:%d dst=%s:%d proto=%v",
srcIP, srcPort, dstIP, dstPort, proto)
Expand Down Expand Up @@ -704,12 +737,12 @@ func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) {

func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
log.Tracef("couldn't decode layer, err: %s", err)
m.logger.Trace("couldn't decode packet, err: %s", err)
return false
}

if len(d.decoded) < 2 {
log.Tracef("not enough levels in network packet")
m.logger.Trace("packet doesn't have network and transport layers")
return false
}
return true
Expand Down Expand Up @@ -953,3 +986,34 @@ func (m *Manager) SetLogLevel(level log.Level) {
m.logger.SetLevel(nblog.Level(level))
}
}

func (m *Manager) EnableRouting() error {
m.mutex.Lock()
defer m.mutex.Unlock()

return m.determineRouting()
}

func (m *Manager) DisableRouting() error {
m.mutex.Lock()
defer m.mutex.Unlock()

if m.forwarder == nil {
return nil
}

m.routingEnabled = false
m.nativeRouter = false

// don't stop forwarder if in use by netstack
if m.netstack && m.localForwarding {
return nil
}

m.forwarder.Stop()
m.forwarder = nil

log.Debug("forwarder stopped")

return nil
}
1 change: 1 addition & 0 deletions client/firewall/uspfilter/uspfilter_filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
}

manager, err := Create(ifaceMock, false)
require.NoError(tb, manager.EnableRouting())
require.NoError(tb, err)
require.NotNil(tb, manager)
require.True(tb, manager.routingEnabled)
Expand Down
12 changes: 6 additions & 6 deletions client/internal/routemanager/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -286,15 +286,15 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
m.updateClientNetworks(updateSerial, filteredClientRoutes)
m.notifier.OnNewRoutes(filteredClientRoutes)
}
m.clientRoutes = newClientRoutesIDMap

if m.serverRouter != nil {
err := m.serverRouter.updateRoutes(newServerRoutesMap)
if err != nil {
return err
}
if m.serverRouter == nil {
return nil
}

m.clientRoutes = newClientRoutesIDMap
if err := m.serverRouter.updateRoutes(newServerRoutesMap); err != nil {
return fmt.Errorf("update routes: %w", err)
}

return nil
}
Expand Down
12 changes: 9 additions & 3 deletions client/internal/routemanager/server_nonandroid.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,15 @@ func (m *serverRouter) updateRoutes(routesMap map[route.ID]*route.Route) error {
}

if len(m.routes) > 0 {
err := systemops.EnableIPForwarding()
if err != nil {
return err
if err := systemops.EnableIPForwarding(); err != nil {
return fmt.Errorf("enable ip forwarding: %w", err)
}
if err := m.firewall.EnableRouting(); err != nil {
return fmt.Errorf("enable routing: %w", err)
}
} else {
if err := m.firewall.DisableRouting(); err != nil {
return fmt.Errorf("disable routing: %w", err)
}
}

Expand Down

0 comments on commit b41de7f

Please sign in to comment.