Skip to content

Commit

Permalink
add beginnings of nftables support + tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Charlie999 committed Dec 6, 2024
1 parent d259dca commit e433107
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 2 deletions.
89 changes: 88 additions & 1 deletion cursed/modules/networking/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
from ipaddress import IPv4Network, IPv6Network
import datetime
import ipaddress
import json
import typing
from enum import Enum
from typing import Protocol
from abc import abstractmethod
import nftables

import dateutil.parser
from pyroute2 import NDB, WireGuard, IPRoute
Expand Down Expand Up @@ -132,3 +134,88 @@ def __init__(self, peer_addrs, local_addrs, listen_port, peer_pubkey, own_privke
self.own_privkey = own_privkey
self.listen_port = listen_port
super().__init__(TunnelInterface.TunnelType.WIREGUARD, peer_addrs, local_addrs, auto_name=auto_name)

class NFTablesEntry(Protocol):
""" Manage firewall rules """
class RuleType(Enum):
""" Possible nftables rule types """
ADD = "add"
FLUSH = "flush"
DELETE = "delete"
REPLACE = "replace"
CREATE = "create"
INSERT = "insert"
# As of now have not included RESET as this system is more for setting up dial-in firewalls
# TODO: Review this

class ObjectType(Enum):
""" Possible nftables object types """
TABLE = "table"
SET = "set"
CHAIN = "chain"
RULE = "rule"
# TODO: The rest?

obj_type: ObjectType
# TODO: methods

class NFTablesMatch:
""" Describes a single nftables match expression """
class OperatorType(Enum):
""" nftables builtin operators """
EQUAL = "eq"
NOT_EQUAL = "ne"
LESS_THAN = "lt"
GREATER_THAN = "gt"
LESS_EQUAL = "le"
GREATER_EQUAL = "ge"
NONE = "" # TODO: determine whether this is actually needed!

left: str
right: str
op: OperatorType

def __init__(self, left: str, op: OperatorType, right: str):
""" Generic constructor for match expression """
self.left = left
self.right = right
self.op = op

def convert_to_dict(self):
""" Convert into libnftables JSON schema compliant format """
return {"left":self.left, "right":self.right, "op":self.op.value}

class NFTablesStatement:
""" Describes an nftables statement """
class StatementType(Enum):
""" nftables statement types """
ACCEPT = {"name":"accept","needs_extra":False}
DROP = {"name":"drop","needs_extra":False}
QUEUE = {"name":"accept","needs_extra":None} # None means it **CAN** have extra
CONTINUE = {"name":"continue","needs_extra":False}
RETURN = {"name":"return","needs_extra":False}
JUMP = {"name":"jump","needs_extra":True}
GOTO = {"name":"goto","needs_extra":True}
REJECT = {"name":"reject","needs_extra":None}
COUNTER = {"name":"counter","needs_extra":None}
LIMIT = {"name":"limit","needs_extra":True}
DNAT = {"name":"dnat","needs_extra":True}
SNAT = {"name":"snat","needs_extra":True}
MASQUERADE = {"name":"masquerade","needs_extra":False}

s_type: StatementType
extra: str # For types that need it, for instance `dnat [[to 192.168.1.1]]` == {type: DNAT, extra: "to 192.168.1.1"}

def __init__(self, s_type: StatementType, extra: typing.Optional[str] = None):
""" Generic constructor """#
if s_type.value["needs_extra"] is not False and extra is None:
raise ValueError("Extra information must be provided for this statment type: "+s_type.name)
if s_type.value["needs_extra"] is False and extra is not None:
raise ValueError("Extra information must not be provided for this statement type: "+s_type.name)
self.s_type = s_type
self.extra = extra

def convert_to_dict(self):
""" Convert into libnftables JSON schema compliant format """
return {self.s_type.value["name"]: self.extra}

21 changes: 20 additions & 1 deletion cursed/run_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pyroute2 import NDB, WireGuard, IPRoute

class TestWireguardTunnel(unittest.TestCase):
def test_wg_lifecycle(self):
def test_wg_interface_lifecycle(self):
import modules.networking
tun = modules.networking.WireguardTunnel(
[ipaddress.ip_network("192.0.0.2/32")],
Expand Down Expand Up @@ -34,6 +34,25 @@ def test_wg_lifecycle(self):
interface = ipr.link_lookup(ifname=tun.ifname)
self.assertEqual(len(interface), 0) # Have we succesfully got rid of it?

class TestNFTablesFeatures(unittest.TestCase):
def test_nftables_build_simple_statement(self):
from modules.networking import NFTablesStatement
statement = NFTablesStatement(NFTablesStatement.StatementType.ACCEPT)

self.assertEqual(statement.convert_to_dict(), {"accept":None})

def test_nftables_build_complex_statement(self):
from modules.networking import NFTablesStatement
statement = NFTablesStatement(NFTablesStatement.StatementType.REJECT,"with icmpv6 type no-route")

self.assertEqual(statement.convert_to_dict(), {"reject":"with icmpv6 type no-route"})

def test_nftables_match(self):
from modules.networking import NFTablesMatch
match = NFTablesMatch("ip length", NFTablesMatch.OperatorType.EQUAL, "1000")

self.assertEqual(match.convert_to_dict(), {"left":"ip length","right":"1000","op":"eq"})

# Run the tests!
if __name__ == '__main__':
unittest.main()

0 comments on commit e433107

Please sign in to comment.