Skip to content

Commit

Permalink
Refactor radix cache helper functions to use iterative approaches
Browse files Browse the repository at this point in the history
Replace the recursive implementations of `_match_prefix_helper`, `_print_helper`, and `_total_size_helper` with iterative ones (using an explicit stack where a traversal is needed), so that very deep trees cannot hit Python's recursion limit and raise `RecursionError`. Simplify return-value handling in `_match_prefix_helper` by removing the `value` and `last_node` out-parameters and returning a `(value, last_node)` tuple directly. Replace the recursion in `_insert_helper` with a `while` loop for better readability and maintainability. These changes are intended to improve performance and reliability while keeping behavior the same.
  • Loading branch information
luzengxiangcn committed Feb 5, 2025
1 parent 4885b90 commit 4447c79
Showing 1 changed file with 45 additions and 35 deletions.
80 changes: 45 additions & 35 deletions python/sglang/srt/mem_cache/radix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,12 @@ def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]:
if self.disable:
return [], self.root_node

value = []
last_node = [self.root_node]
self._match_prefix_helper(self.root_node, key, value, last_node)
value, last_node = self._match_prefix_helper(self.root_node, key)
if value:
value = torch.concat(value)
else:
value = torch.tensor([], dtype=torch.int32)
return value, last_node[0]
return value, last_node

def insert(self, key: List, value=None):
if self.disable:
Expand Down Expand Up @@ -191,7 +189,7 @@ def pretty_print(self):
print(f"#tokens: {self.total_size()}")

def total_size(self):
    """Return the total size of the cache, as reported by the tree walk.

    ``pretty_print`` shows this number as ``#tokens``. The work is delegated
    to ``_total_size_helper``, which takes no arguments and starts from
    ``self.root_node`` itself, so nothing is passed here.
    """
    return self._total_size_helper()

def evict(self, num_tokens: int, evict_callback: Callable):
if self.disable:
Expand Down Expand Up @@ -253,24 +251,23 @@ def protected_size(self):

##### Internal Helper Functions #####

def _match_prefix_helper(self, node: TreeNode, key: List):
    """Follow ``key`` down the tree from ``node`` and collect matched values.

    Implemented as a loop rather than recursion so the depth of the tree
    is not bounded by the interpreter's recursion limit.

    Args:
        node: Node to start from. Its ``last_access_time`` is refreshed,
            as is that of every child visited.
        key: Key to match. The matched prefix is sliced off the front
            after each fully matched child.

    Returns:
        A ``(value, node)`` tuple. ``value`` is the list of ``value``
        entries of the matched nodes, in the order they were visited.
        ``node`` is the last matched node, or the starting node when
        nothing matched. A tuple is returned even for an empty key,
        because ``match_prefix`` unpacks the result.
    """
    node.last_access_time = time.time()

    value = []
    while len(key) > 0 and key[0] in node.children:
        child = node.children[key[0]]
        child.last_access_time = time.time()
        prefix_len = _key_match(child.key, key)
        if prefix_len < len(child.key):
            # Only part of the child's key matched: split the child at the
            # match length and stop at the node `_split_node` hands back.
            new_node = self._split_node(child.key, child, prefix_len)
            value.append(new_node.value)
            node = new_node
            break
        # The whole child key matched: descend and continue with the
        # unmatched remainder of the key.
        value.append(child.value)
        node = child
        key = key[prefix_len:]
    return value, node

def _split_node(self, key, child: TreeNode, split_len: int):
# new_node -> child
Expand All @@ -291,22 +288,25 @@ def _insert_helper(self, node: TreeNode, key: List, value):
if len(key) == 0:
return 0

if key[0] in node.children.keys():
total_prefix_length = 0
while len(key) > 0 and key[0] in node.children.keys():
child = node.children[key[0]]
child.last_access_time = time.time()
prefix_len = _key_match(child.key, key)

total_prefix_length += prefix_len
if prefix_len == len(child.key):
if prefix_len == len(key):
return prefix_len
break
else:
key = key[prefix_len:]
value = value[prefix_len:]
return prefix_len + self._insert_helper(child, key, value)

new_node = self._split_node(child.key, child, prefix_len)
return prefix_len + self._insert_helper(
new_node, key[prefix_len:], value[prefix_len:]
)
node = child
else:
new_node = self._split_node(child.key, child, prefix_len)
key = key[prefix_len:]
value = value[prefix_len:]
node = new_node
break

if len(key):
new_node = TreeNode()
Expand All @@ -315,12 +315,21 @@ def _insert_helper(self, node: TreeNode, key: List, value):
new_node.value = value
node.children[key[0]] = new_node
self.evictable_size_ += len(value)
return 0
return total_prefix_length

def _print_helper(self, node: "TreeNode", indent: int):
    """Print every descendant of ``node``, one line per node.

    Each line shows the length of the node's key, the first ten elements
    of the key and its ``lock_ref`` count. ``node`` itself is not printed:
    its direct children appear at ``indent`` and each further level is
    indented by two more spaces, in the same order a recursive pre-order
    walk over ``children`` would produce.

    Args:
        node: Node whose subtree is printed.
        indent: Number of leading spaces for the direct children of ``node``.
    """
    # Children are pushed in reverse so that popping from the end of the
    # list visits siblings in their dict order despite the LIFO stack.
    stack = [(child, indent) for child in reversed(list(node.children.values()))]
    while stack:
        current_node, current_indent = stack.pop()
        print(
            " " * current_indent,
            len(current_node.key),
            current_node.key[:10],
            f"r={current_node.lock_ref}",
        )
        stack.extend(
            (child, current_indent + 2)
            for child in reversed(list(current_node.children.values()))
        )

def _delete_leaf(self, node):
for k, v in node.parent.children.items():
Expand All @@ -329,13 +338,14 @@ def _delete_leaf(self, node):
del node.parent.children[k]
self.evictable_size_ -= len(node.key)

def _total_size_helper(self):
    """Sum ``len(node.value)`` over the tree rooted at ``self.root_node``.

    The walk uses an explicit stack instead of recursion. A node whose
    ``evicted`` attribute is true contributes nothing and its children are
    not visited, which is what the earlier recursive version did by
    returning 0 for such a node.

    Returns:
        The accumulated length as an int; 0 when the root itself is evicted.
    """
    total_size = 0
    stack = [self.root_node]
    while stack:
        current_node = stack.pop()
        if current_node.evicted:
            # Skip the node and, by not pushing its children, its subtree.
            continue
        total_size += len(current_node.value)
        stack.extend(current_node.children.values())
    return total_size

def _collect_leaves(self):
ret_list = []
Expand Down

0 comments on commit 4447c79

Please sign in to comment.