From 2daa81c8ce56d436556360cfcd8065f8049222ab Mon Sep 17 00:00:00 2001
From: skcd
Date: Thu, 28 Nov 2024 15:59:33 +0000
Subject: [PATCH] [sidecar] sculpt out run node

---
Review notes (text in this section is not part of the commit message):
* Line structure and indentation were rebuilt following rustfmt conventions;
  confirm the context lines against the target file before applying.
* NOTE(review): generic type parameters were missing from this copy of the
  patch. `Option<ToolInputPartial>`, `Option<Reward>`, `Option<ActionObservation>`
  and `Vec<LLMClientMessage>` follow from the imports and usages in this file;
  `Option<String>` for `feedback` / `message` and `Vec<ToolType>` for `tools`
  are assumptions - TODO confirm.
* NOTE(review): the blob hashes on the `index` line predate the review changes.
* `get_trajectory` returns the requested node first and its ancestors after it;
  confirm this is the order `to_messages` will expect.

 sidecar/src/mcts/action_node.rs | 113 +++++++++++++++++++++++++++++---
 1 file changed, 103 insertions(+), 10 deletions(-)

diff --git a/sidecar/src/mcts/action_node.rs b/sidecar/src/mcts/action_node.rs
index badc91864..59256f1a0 100644
--- a/sidecar/src/mcts/action_node.rs
+++ b/sidecar/src/mcts/action_node.rs
@@ -1,8 +1,11 @@
 use std::collections::HashMap;
 
-use color_eyre::owo_colors::OwoColorize;
+use llm_client::clients::types::LLMClientMessage;
 
-use crate::agentic::tool::{input::ToolInputPartial, r#type::ToolType};
+use crate::{
+    agentic::tool::{input::ToolInputPartial, r#type::ToolType},
+    user_context::types::UserContext,
+};
 
 use super::{selector::selector::Selector, value_function::reward::Reward};
 
@@ -13,30 +16,38 @@ pub struct ActionObservation {
     expect_correction: bool,
 }
 
+/// How do we get the action nodes to be part of the LLM inference, where we can generate
+/// more steps if required etc.? That's the important bit here.
 pub struct ActionNode {
     index: usize,
-    _action: Option<ToolInputPartial>,
-    _feedback: Option<String>,
-    _is_duplicate: bool,
+    action: Option<ToolInputPartial>,
+    feedback: Option<String>,
+    is_duplicate: bool,
     reward: Option<Reward>,
     visits: u32,
     value: f32,
     max_expansions: usize,
     observation: Option<ActionObservation>,
+    // this tracks the context associated with the current action node
+    user_context: UserContext,
+    // the message associated with the node
+    message: Option<String>,
 }
 
 impl ActionNode {
     pub fn new(index: usize, max_expansions: usize) -> Self {
         Self {
             index,
-            _action: None,
-            _feedback: None,
-            _is_duplicate: false,
+            action: None,
+            feedback: None,
+            is_duplicate: false,
             reward: None,
             visits: 0,
             value: 0.0,
             max_expansions,
             observation: None,
+            user_context: UserContext::default(),
+            message: None,
         }
     }
 
@@ -60,6 +71,26 @@ impl ActionNode {
             .map(|observation| observation.terminal)
             .unwrap_or_default()
     }
+
+    /// Clears the state left behind by a previous run: `reward`, `visits`, `value`,
+    /// `observation`, `is_duplicate`, `feedback` and `action`. The `index`,
+    /// `max_expansions`, `user_context` and `message` fields are left untouched.
+    fn reset(&mut self) {
+        self.reward = None;
+        self.visits = 0;
+        self.value = 0.0;
+        self.observation = None;
+        self.is_duplicate = false;
+        self.feedback = None;
+        self.action = None;
+    }
+
+    /// Builds the LLM messages for the given nodes.
+    ///
+    /// TODO: not implemented yet, currently always returns an empty list.
+    fn to_messages(_nodes: Vec<&Self>) -> Vec<LLMClientMessage> {
+        vec![]
+    }
 }
 
 pub struct SearchTree {
@@ -73,6 +104,7 @@ pub struct SearchTree {
     /// maximum depth the nodes can go to
     max_depth: u32,
     selector: Selector,
+    tools: Vec<ToolType>,
 }
 
 impl SearchTree {
@@ -278,7 +310,7 @@ impl SearchTree {
                 && node_children
                     .to_vec()
                     .into_iter()
-                    .filter_map(|child| child._action.clone())
+                    .filter_map(|child| child.action.clone())
                     .map(|tool_parameters| tool_parameters.to_tool_type())
                     .any(|tool_type| {
                         bad_child_actions
@@ -487,7 +519,7 @@ impl SearchTree {
             return true;
         }
         let node = node.expect("if let None to hold");
-        node._is_duplicate
+        node.is_duplicate
     }
 
     /// Recursively grabs all the expandable node starting from the root
@@ -594,4 +626,65 @@ impl SearchTree {
         self.add_node_to_parent(node_index, child_node_index);
         Some(child_node_index)
     }
+
+    /// Detaches every child of `node_index` by clearing its entry in
+    /// `node_to_children` and removing the children's entries in `node_to_parent`.
+    ///
+    /// Does nothing when the node does not exist.
+    // NOTE(review): only the edges are removed here; confirm whether the child
+    // nodes (and their own subtrees) should also be dropped from the tree.
+    fn reset_children_for_node(&mut self, node_index: usize) {
+        let children_indices = match self.get_node(node_index) {
+            Some(node) => self.children_indices(node).unwrap_or_default(),
+            None => return,
+        };
+        // remove all the child edges from node_to_children
+        if let Some(children) = self.node_to_children.get_mut(&node_index) {
+            children.clear();
+        }
+        // remove all the parent edges from node_to_parent
+        for child_index in children_indices {
+            self.node_to_parent.remove(&child_index);
+        }
+    }
+
+    /// Collects the node at `node_index` followed by its chain of ancestors, in
+    /// that order, stopping at the first node for which `parent` returns `None`.
+    ///
+    /// Returns an empty list when the node does not exist.
+    fn get_trajectory(&self, node_index: usize) -> Vec<&ActionNode> {
+        // walk the parent links iteratively instead of recursing, so we do not
+        // allocate one intermediate Vec per ancestor
+        let mut nodes = vec![];
+        let mut current = self.get_node(node_index);
+        while let Some(node) = current {
+            nodes.push(node);
+            current = self.parent(node);
+        }
+        nodes
+    }
+
+    /// Re-runs the node at `node_index`: resets the node, detaches its children
+    /// and gathers the trajectory leading to it.
+    ///
+    /// Does nothing when the node does not exist. Picking and executing the next
+    /// action is still to be implemented (see the TODO below).
+    pub fn run_node(&mut self, node_index: usize) {
+        // reset the node
+        match self.get_node_mut(node_index) {
+            Some(node) => node.reset(),
+            None => return,
+        }
+        // reset the graph at this node as well
+        self.reset_children_for_node(node_index);
+
+        // first we generate the message which we want to run inference for the
+        // trajectory
+        let _node_trajectory = self.get_trajectory(node_index);
+
+        // TODO: pick the next action we want to take over here
+        // - execute the action
+        // - add the observation to the node
+        // - generate the value reward
+    }
 }