# Source code for ipprefixtrie.ipprefixtrie

# -*- coding: utf-8 -*-
#
# This file is part of IPPrefixTrie.
#
# Copyright (C) 2025 Interstellio IO (PTY) LTD.
#
# IPPrefixTrie is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or any later version.

# IPPrefixTrie is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.

# You should have received a copy of the GNU Lesser General Public License
# along with IPPrefixTrie. If not, see https://www.gnu.org/licenses/.
from typing import Any
from typing import Generator
import ipaddress

from .exceptions import (InvalidPrefixError,
                         PrefixNotFoundError)


class _IPPrefixTrieNode:
    """
    A single node of the binary prefix trie (internal use only).

    Every node has up to two children, one per bit value, and may mark the
    end of a stored prefix.

    Attributes:
        left (_IPPrefixTrieNode): Child reached by a 0 bit, or None.
        right (_IPPrefixTrieNode): Child reached by a 1 bit, or None.
        is_prefix (bool): True when a stored prefix terminates at this node.
        metadata (Any): Value attached to the prefix ending here.
    """
    # Fixed attribute set: no per-instance __dict__ is created.
    __slots__ = ("left", "right", "is_prefix", "metadata")

    def __init__(self):
        # A fresh node is an unmarked leaf with nothing attached.
        self.left = self.right = None
        self.metadata = None
        self.is_prefix = False


def _iterate_bits(data: bytes) -> Generator[tuple[int, int, bytes],
                                            None, None]:
    """Walk ``data`` bit by bit, most significant bit of each byte first.

    Args:
        data (bytes): The packed bytes to iterate over.

    Yields:
        tuple: ``(bit_pos, bit, matched)`` where ``bit_pos`` is the zero-based
        position of the bit, ``bit`` is 0 or 1, and ``matched`` is a
        ``bytes`` object as long as ``data`` holding every bit seen so far
        (including the current one), with all later bits zero.
    """
    # Zero-filled accumulator so each snapshot is always full length.
    seen = bytearray(len(data))
    position = 0

    for index, octet in enumerate(data):
        # Highest bit of the byte comes first.
        for shift in range(7, -1, -1):
            mask = 1 << shift
            if octet & mask:
                seen[index] |= mask
                value = 1
            else:
                value = 0

            # Hand out an immutable snapshot of the bits consumed so far.
            yield position, value, bytes(seen)
            position += 1


class IPPrefixTrie(object):
    """
    A binary trie for storing and searching IP prefixes efficiently.

    IPv4 and IPv6 prefixes are kept in two separate tries. Each level of a
    trie corresponds to one bit of the network address (most significant
    bit first); a node flagged ``is_prefix`` marks the end of a stored
    prefix and carries its metadata.
    """
    __slots__ = ("__ipv4_root", "__ipv6_root")

    def __init__(self):
        self.clear()

    def clear(self):
        """Resets the trie to an empty state.

        Creates fresh, separate root nodes for IPv4 and IPv6 prefixes,
        discarding anything stored previously.
        """
        self.__ipv4_root = _IPPrefixTrieNode()
        self.__ipv6_root = _IPPrefixTrieNode()

    def _resolve(self, prefix):
        """Parses ``prefix`` and picks the root of the matching trie.

        Args:
            prefix (str): The IPv4 or IPv6 prefix in CIDR notation.

        Raises:
            InvalidPrefixError: If the prefix format is invalid.

        Returns:
            tuple: (parsed network object, root node for its IP version).
        """
        try:
            network = ipaddress.ip_network(prefix)
        except ValueError as e:
            raise InvalidPrefixError(str(e)) from None

        if network.version == 4:
            return network, self.__ipv4_root
        return network, self.__ipv6_root

    def insert(self, prefix: str, metadata: Any = None) -> None:
        """Inserts an IP prefix into the trie.

        Inserting a prefix that already exists replaces its metadata.

        Args:
            prefix (str): The IPv4 or IPv6 prefix in CIDR notation.
            metadata (Any, optional): Value to associate with the prefix.
                When omitted (None) an empty dict is stored.

        Raises:
            InvalidPrefixError: If the prefix format is invalid.
        """
        network, node = self._resolve(prefix)
        prefix_len = network.prefixlen
        prefix_bin = network.network_address.packed

        for bit_pos, bit, _ in _iterate_bits(prefix_bin):
            if bit_pos >= prefix_len:
                break

            if bit == 0:
                # Create the left child on demand.
                if node.left is None:
                    node.left = _IPPrefixTrieNode()
                node = node.left
            else:
                # Create the right child on demand.
                if node.right is None:
                    node.right = _IPPrefixTrieNode()
                node = node.right

        node.is_prefix = True
        # Only a missing value defaults to {}; falsy metadata such as 0, ""
        # or [] is stored as given.
        node.metadata = {} if metadata is None else metadata

    def get_exact(self, prefix: str,
                  raise_error: bool = True) -> tuple[str, Any] | None:
        """Retrieves an exact prefix match.

        Args:
            prefix (str): The IPv4 or IPv6 prefix in CIDR notation.
            raise_error (bool, optional): If True, raises an error if the
                prefix is not found. Defaults to True.

        Raises:
            InvalidPrefixError: If the prefix format is invalid.
            PrefixNotFoundError: If the prefix is not found and
                `raise_error` is True.

        Returns:
            tuple[str, Any] | None: A tuple containing the prefix as a
            string and its associated metadata if found, otherwise None.
        """
        network, node = self._resolve(prefix)
        prefix_len = network.prefixlen
        prefix_bin = network.network_address.packed

        # Follow the prefix bits down the trie; a missing child means the
        # prefix was never stored.
        for bit_pos, bit, _ in _iterate_bits(prefix_bin):
            if bit_pos >= prefix_len:
                break

            node = node.left if bit == 0 else node.right
            if node is None:
                break

        if node is not None and node.is_prefix:
            return str(network), node.metadata
        if raise_error:
            raise PrefixNotFoundError(str(network))
        return None

    def get_longest(self, prefix: str,
                    raise_error: bool = True) -> tuple[str, Any] | None:
        """Finds the longest stored prefix that covers the given prefix.

        Only stored prefixes that are equal to, or less specific than, the
        given prefix are considered. A stored zero-length prefix (default
        route) matches when nothing more specific does.

        Args:
            prefix (str): The IPv4 or IPv6 prefix in CIDR notation.
            raise_error (bool, optional): If True, raises an error if no
                match is found. Defaults to True.

        Raises:
            InvalidPrefixError: If the prefix format is invalid.
            PrefixNotFoundError: If the match is not found and
                `raise_error` is True.

        Returns:
            tuple[str, Any] | None: A tuple containing the matched prefix
            as a string and its associated metadata if found, otherwise
            None.
        """
        network, node = self._resolve(prefix)
        prefix_len = network.prefixlen
        prefix_bin = network.network_address.packed

        # The root represents the zero-length prefix; it is the fallback
        # match when it has been stored.
        best_node = node if node.is_prefix else None
        best_length = 0
        best_bytes = bytes(len(prefix_bin))

        for bit_pos, bit, matched_bytes in _iterate_bits(prefix_bin):
            # Never descend past the queried prefix length: anything deeper
            # is more specific than the query and does not cover it.
            if bit_pos >= prefix_len:
                break

            node = node.left if bit == 0 else node.right
            if node is None:
                break

            if node.is_prefix:
                best_node = node
                best_length = bit_pos + 1
                # Snapshot the bytes at the moment of the match; bits
                # consumed later must not leak into the reported address.
                best_bytes = matched_bytes

        if best_node is None:
            if raise_error:
                raise PrefixNotFoundError(str(network))
            return None

        network_address = ipaddress.ip_address(best_bytes)
        return f"{network_address}/{best_length}", best_node.metadata

    def get_orlonger(self, prefix: str) -> Generator[tuple[str, Any],
                                                     None, None]:
        """Yields the given prefix (if stored) and all more specific ones.

        Results are produced breadth-first, i.e. ordered by increasing
        prefix length. As this is a generator, validation of ``prefix``
        happens when iteration starts.

        Args:
            prefix (str): The IPv4 or IPv6 prefix in CIDR notation.

        Raises:
            InvalidPrefixError: If the prefix format is invalid.

        Yields:
            tuple: (matching prefix as str, metadata)
        """
        network, node = self._resolve(prefix)
        prefix_len = network.prefixlen
        prefix_bin = network.network_address.packed
        max_len = network.max_prefixlen

        # Step 1: Walk down to the node representing the given prefix.
        for bit_pos, bit, _ in _iterate_bits(prefix_bin):
            if bit_pos >= prefix_len:
                break

            node = node.left if bit == 0 else node.right
            if node is None:
                return

        # Step 2: Breadth-first traversal of the subtree. The starting
        # depth is the prefix length itself and the starting bytes are the
        # network address, whose bits beyond the prefix length are zero.
        level = [(node, prefix_len, prefix_bin)]
        while level:
            next_level = []
            for current, depth, packed in level:
                if current.is_prefix:
                    network_address = ipaddress.ip_address(packed)
                    yield f"{network_address}/{depth}", current.metadata

                if depth >= max_len:
                    continue

                if current.left is not None:
                    # The bit at `depth` is already zero in `packed`, so
                    # the left child shares its parent's bytes.
                    next_level.append((current.left, depth + 1, packed))

                if current.right is not None:
                    right_bytes = bytearray(packed)
                    right_bytes[depth // 8] |= 1 << (7 - (depth % 8))
                    next_level.append(
                        (current.right, depth + 1, bytes(right_bytes)))
            level = next_level

    def delete(self, prefix: str, raise_error: bool = True) -> bool:
        """Deletes the given prefix from the trie.

        If the prefix node has no children, the now-unused nodes above it
        are removed up to the next stored prefix or branching point.

        Args:
            prefix (str): The IPv4 or IPv6 prefix in CIDR notation.
            raise_error (bool, optional): If True, raises an error if the
                prefix is not found. Defaults to True.

        Raises:
            InvalidPrefixError: If the prefix format is invalid.
            PrefixNotFoundError: If the match is not found and
                `raise_error` is True.

        Returns:
            bool: True if deleted, False if not found.
        """
        network, node = self._resolve(prefix)
        prefix_len = network.prefixlen
        prefix_bin = network.network_address.packed

        # (parent, bit) pairs along the path, used for cleanup afterwards.
        path_traversed = []

        for bit_pos, bit, _ in _iterate_bits(prefix_bin):
            if bit_pos >= prefix_len:
                break

            path_traversed.append((node, bit))
            node = node.left if bit == 0 else node.right
            if node is None:
                break

        if node is None or not node.is_prefix:
            if raise_error:
                raise PrefixNotFoundError(str(network))
            return False  # Prefix not found

        # Unset the prefix flag and remove metadata.
        node.is_prefix = False
        node.metadata = None

        # Prune nodes bottom-up while they are unmarked leaves.
        while path_traversed:
            parent, bit = path_traversed.pop()
            child = parent.left if bit == 0 else parent.right
            if (child.is_prefix
                    or child.left is not None
                    or child.right is not None):
                break  # Still needed: stored prefix or has descendants.

            if bit == 0:
                parent.left = None
            else:
                parent.right = None

        return True