From dd149dfe5d01c658b53c27e7b422cf307c2f0758 Mon Sep 17 00:00:00 2001
From: Miauwkeru <Miauwkeru@users.noreply.github.com>
Date: Wed, 9 Oct 2024 11:22:34 +0000
Subject: [PATCH] Change type definitions

---
 flow/record/fieldtypes/net/ip.py | 24 +++++++++++++++++-------
 1 file changed, 17 insertions(+), 7 deletions(-)

diff --git a/flow/record/fieldtypes/net/ip.py b/flow/record/fieldtypes/net/ip.py
index 660118c..3a661e1 100644
--- a/flow/record/fieldtypes/net/ip.py
+++ b/flow/record/fieldtypes/net/ip.py
@@ -1,20 +1,30 @@
 from __future__ import annotations
 
-from ipaddress import ip_address, ip_network
+from ipaddress import (
+    IPv4Address,
+    IPv4Network,
+    IPv6Address,
+    IPv6Network,
+    ip_address,
+    ip_network,
+)
 from typing import Union
 
 from flow.record.base import FieldType
 from flow.record.fieldtypes import defang
 
+_IPNetwork = Union[IPv4Network, IPv6Network]
+_IPAddress = Union[IPv4Address, IPv6Address]
+
 
 class ipaddress(FieldType):
     val = None
     _type = "net.ipaddress"
 
-    def __init__(self, addr: Union[str, int]):
+    def __init__(self, addr: str | int | bytes):
         self.val = ip_address(addr)
 
-    def __eq__(self, b: Union[str, int]) -> bool:
+    def __eq__(self, b: str | int | bytes) -> bool:
         try:
             return self.val == ip_address(b)
         except ValueError:
@@ -46,10 +56,10 @@ class ipnetwork(FieldType):
     val = None
     _type = "net.ipnetwork"
 
-    def __init__(self, addr: Union[str, int]):
+    def __init__(self, addr: str | int | bytes):
         self.val = ip_network(addr)
 
-    def __eq__(self, b: Union[str, int]) -> bool:
+    def __eq__(self, b: str | int | bytes) -> bool:
         try:
             return self.val == ip_network(b)
         except ValueError:
@@ -59,7 +69,7 @@ def __hash__(self) -> int:
         return hash(self.val)
 
     @staticmethod
-    def _is_subnet_of(a: ip_network, b: ip_network) -> bool:
+    def _is_subnet_of(a: _IPNetwork, b: _IPNetwork) -> bool:
         try:
             # Always false if one is v4 and the other is v6.
             if a._version != b._version:
@@ -68,7 +78,7 @@ def _is_subnet_of(a: ip_network, b: ip_network) -> bool:
         except AttributeError:
             raise TypeError("Unable to test subnet containment " "between {} and {}".format(a, b))
 
-    def __contains__(self, b: Union[str, int, ip_address]) -> bool:
+    def __contains__(self, b: str | int | bytes | _IPAddress) -> bool:
         try:
             return self._is_subnet_of(ip_network(b), self.val)
         except (ValueError, TypeError):