Skip to content

Commit

Permalink
feat: add MAC rule
Browse files Browse the repository at this point in the history
  • Loading branch information
xishang0128 committed Aug 10, 2024
1 parent ae98c23 commit afe2880
Show file tree
Hide file tree
Showing 10 changed files with 421 additions and 0 deletions.
67 changes: 67 additions & 0 deletions component/arp/arp.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Package arp provides a general interface to retrieve the ARP table
// on both Linux and Windows

// kanged from https://github.com/situation-sh/situation/blob/main/modules/arp/arp.go

package arp

import (
"net"
"net/netip"
"sync"

"github.com/metacubex/mihomo/log"
)

var (
table []ARPEntry
mu sync.Mutex
)

type ARPEntry struct {
MAC net.HardwareAddr
IP net.IP
}

func IsReserved(ip net.IP) bool {
if ip4 := ip.To4(); ip4 != nil {
return ip4[3] == 0 || ip4[3] == 255
}
return false
}

func refreshARPTable() {
var err error
newTable, err := GetARPTable()
if err != nil {
log.Warnln("failed to refresh ARP table")
return
}
mu.Lock()
table = newTable
mu.Unlock()
}

func IPToMac(ip netip.Addr) net.HardwareAddr {
mu.Lock()
defer mu.Unlock()

if len(table) == 0 {
refreshARPTable()
}

for _, entry := range table {
if entry.IP.Equal(ip.AsSlice()) {
return entry.MAC
}
}
if ip.IsPrivate() {
refreshARPTable()
for _, entry := range table {
if entry.IP.Equal(ip.AsSlice()) {
return entry.MAC
}
}
}
return nil
}
59 changes: 59 additions & 0 deletions component/arp/arp_linux.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package arp

import (
"fmt"
"net"

"github.com/sagernet/netlink"
)

func neighMAC(n netlink.Neigh) net.HardwareAddr {
length := len(n.HardwareAddr)
mac := make(net.HardwareAddr, length)
copy(mac, n.HardwareAddr)
return mac
}

func neighIP(n netlink.Neigh) net.IP {
length := len(n.IP)
ip := make(net.IP, length)
copy(ip, n.IP)
return ip
}

func neighToARPEntry(n netlink.Neigh) ARPEntry {
return ARPEntry{
MAC: neighMAC(n),
IP: neighIP(n),
}
}

func GetARPTable() ([]ARPEntry, error) {
entries := make([]ARPEntry, 0)

links, err := netlink.LinkList()
if err != nil {
return nil, err
}

for _, link := range links {
attr := link.Attrs()
neighs, err := netlink.NeighList(attr.Index, 0)
if err != nil {
fmt.Println(err)
continue
}
for _, neigh := range neighs {
entry := neighToARPEntry(neigh)

if IsReserved(entry.IP) {
continue
}

if entry.IP.IsGlobalUnicast() {
entries = append(entries, entry)
}
}
}
return entries, nil
}
7 changes: 7 additions & 0 deletions component/arp/arp_other.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
//go:build !linux && !windows

package arp

func GetARPTable() ([]ARPEntry, error) {
return nil, nil
}
22 changes: 22 additions & 0 deletions component/arp/arp_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package arp

func GetARPTable() ([]ARPEntry, error) {
table, err := GetIpNetTable2()
if err != nil {
return nil, err
}
entries := make([]ARPEntry, 0)
for _, row := range table {
entry := row.ToARPEntry()

// ignore 0 and 255 in case of IPv4
if IsReserved(entry.IP) {
continue
}

if entry.IP.IsGlobalUnicast() {
entries = append(entries, entry)
}
}
return entries, nil
}
52 changes: 52 additions & 0 deletions component/arp/get_ip_net_table2.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
//go:build windows
// +build windows

package arp

import (
"fmt"
"syscall"
"unsafe"

"golang.org/x/sys/windows"
)

var iphlpapi *windows.DLL

func init() {
iphlpapi = windows.MustLoadDLL("Iphlpapi.dll")
}

func GetIpNetTable2() (MIBIpNetTable2, error) {
proc, err := iphlpapi.FindProc("GetIpNetTable2")
if err != nil {
return nil, err
}

free, err := iphlpapi.FindProc("FreeMibTable")
if err != nil {
return nil, err
}

var data *rawMIBIpNetTable2
errno, _, _ := proc.Call(0, uintptr(unsafe.Pointer(&data)))
defer free.Call(uintptr(unsafe.Pointer(data)))

switch syscall.Errno(errno) {
case windows.ERROR_SUCCESS:
err = nil
case windows.ERROR_NOT_ENOUGH_MEMORY:
err = fmt.Errorf("insufficient memory resources are available to complete the operation")
case windows.ERROR_INVALID_PARAMETER:
err = fmt.Errorf("an invalid parameter was passed to the function")
case windows.ERROR_NOT_FOUND:
err = fmt.Errorf("no neighbor IP address entries as specified in the Family parameter were found")
case windows.ERROR_NOT_SUPPORTED:
err = fmt.Errorf("the IPv4 or IPv6 transports are not configured on the local computer")
default:
err = windows.GetLastError()
}

table := data.parse()
return table, err
}
143 changes: 143 additions & 0 deletions component/arp/mib_ipnet_row2.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
//go:build windows
// +build windows

package arp

import (
"encoding/binary"
"net"
"time"
)

const MIBIpNetRow2Size = 88
const SockAddrSize = 28

type SockAddrIn struct {
sinFamily uint16
sinPort uint16
sinAddr net.IP
sinZero []byte
}

func NewSockAddrIn(buffer []byte) SockAddrIn {
addr := SockAddrIn{
sinFamily: binary.LittleEndian.Uint16(buffer[:2]),
sinPort: binary.LittleEndian.Uint16(buffer[2:4]),
sinAddr: net.IP(make([]byte, 4)).To4(),
sinZero: make([]byte, 8),
}
copy(addr.sinAddr, buffer[4:8])
copy(addr.sinZero, buffer[8:16])
return addr
}

func (s SockAddrIn) Family() uint16 {
return s.sinFamily
}

func (s SockAddrIn) Addr() net.IP {
return s.sinAddr.To4()
}

type SockAddrIn6 struct {
sin6Family uint16
sin6Port uint16
sin6FlowInfo uint32
sin6Addr net.IP
sin6ScopeId uint32
}

func NewSockAddrIn6(buffer []byte) SockAddrIn6 {
addr := SockAddrIn6{
sin6Family: binary.LittleEndian.Uint16(buffer[:2]),
sin6Port: binary.LittleEndian.Uint16(buffer[2:4]),
sin6FlowInfo: binary.LittleEndian.Uint32(buffer[4:8]),
sin6Addr: net.IP(make([]byte, 16)).To16(),
sin6ScopeId: binary.LittleEndian.Uint32(buffer[24:28]),
}
copy(addr.sin6Addr, buffer[8:24])
return addr
}

func (s SockAddrIn6) Family() uint16 {
return s.sin6Family
}

func (s SockAddrIn6) Addr() net.IP {
return s.sin6Addr.To16()
}

type SockAddr interface {
Family() uint16
Addr() net.IP
}

func parseSockAddr(buffer []byte) SockAddr {
sockType := binary.LittleEndian.Uint16(buffer[:2])
switch sockType {
case 2: // IPv4
return NewSockAddrIn(buffer[:SockAddrSize])
case 23: // IPv6
return NewSockAddrIn6(buffer[:SockAddrSize])
default:
return nil
}
}

func parsePhysicalAddress(buffer []byte, physicalAddressLength uint32) net.HardwareAddr {
pa := make(net.HardwareAddr, physicalAddressLength)
copy(pa, buffer[:physicalAddressLength])
return pa
}

type MIBIpNetRow2 struct {
address SockAddr
interfaceIndex uint32
interfaceLuid uint64
physicalAddress net.HardwareAddr
physicalAddressLength uint32
flags uint32
reachabilityTime time.Duration
}

func (r MIBIpNetRow2) MAC() net.HardwareAddr {
mac := make(net.HardwareAddr, r.physicalAddressLength)
copy(mac, r.physicalAddress)
return mac
}

func (r MIBIpNetRow2) IP() net.IP {
length := len(r.address.Addr())
ip := make(net.IP, length)
copy(ip, r.address.Addr())
return ip
}

func (r MIBIpNetRow2) ToARPEntry() ARPEntry {
return ARPEntry{
MAC: r.MAC(),
IP: r.IP(),
}
}

type rawMIBIpNetRow2 struct {
address [28]byte
interfaceIndex uint32
interfaceLuid uint64
physicalAddress [32]byte
physicalAddressLength uint32
flags uint32
reachabilityTime uint32
}

func (r rawMIBIpNetRow2) Parse() MIBIpNetRow2 {
return MIBIpNetRow2{
address: parseSockAddr(r.address[:]),
interfaceIndex: r.interfaceIndex,
interfaceLuid: r.interfaceLuid,
physicalAddress: parsePhysicalAddress(r.physicalAddress[:], r.physicalAddressLength),
physicalAddressLength: r.physicalAddressLength,
flags: r.flags,
reachabilityTime: time.Duration(r.reachabilityTime * uint32(time.Millisecond)),
}
}
22 changes: 22 additions & 0 deletions component/arp/mib_ipnet_table2.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
//go:build windows
// +build windows

package arp

const anySize = 1 << 16

type MIBIpNetTable2 []MIBIpNetRow2

type rawMIBIpNetTable2 struct {
numEntries uint32
padding uint32
table [anySize]rawMIBIpNetRow2
}

func (r *rawMIBIpNetTable2) parse() MIBIpNetTable2 {
t := make([]MIBIpNetRow2, r.numEntries)
for i := 0; i < int(r.numEntries); i++ {
t[i] = r.table[i].Parse()
}
return t
}
3 changes: 3 additions & 0 deletions constant/rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ const (
DstPort
InPort
DSCP
Mac
InUser
InName
InType
Expand Down Expand Up @@ -98,6 +99,8 @@ func (rt RuleType) String() string {
return "Uid"
case SubRules:
return "SubRules"
case Mac:
return "Mac"
case AND:
return "AND"
case OR:
Expand Down
Loading

0 comments on commit afe2880

Please sign in to comment.