Skip to content

Commit

Permalink
Merge pull request #1579 from codestoryai/features/sculpt-out-run-node
Browse files Browse the repository at this point in the history
[sidecar] sculpt out run node
  • Loading branch information
theskcd authored Nov 28, 2024
2 parents 8dce32f + 2daa81c commit f62abea
Showing 1 changed file with 93 additions and 10 deletions.
103 changes: 93 additions & 10 deletions sidecar/src/mcts/action_node.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
use std::collections::HashMap;

use color_eyre::owo_colors::OwoColorize;
use llm_client::clients::types::LLMClientMessage;

use crate::agentic::tool::{input::ToolInputPartial, r#type::ToolType};
use crate::{
agentic::tool::{input::ToolInputPartial, r#type::ToolType},
user_context::types::UserContext,
};

use super::{selector::selector::Selector, value_function::reward::Reward};

Expand All @@ -13,30 +16,38 @@ pub struct ActionObservation {
expect_correction: bool,
}

/// How do we make the action nodes part of the LLM inference, so that we can generate
/// more steps if required? That's the important bit here.
pub struct ActionNode {
index: usize,
_action: Option<ToolInputPartial>,
_feedback: Option<String>,
_is_duplicate: bool,
action: Option<ToolInputPartial>,
feedback: Option<String>,
is_duplicate: bool,
reward: Option<Reward>,
visits: u32,
value: f32,
max_expansions: usize,
observation: Option<ActionObservation>,
// this tracks the context associated with the current action node
user_context: UserContext,
// the message associated with the node
message: Option<String>,
}

impl ActionNode {
pub fn new(index: usize, max_expansions: usize) -> Self {
Self {
index,
_action: None,
_feedback: None,
_is_duplicate: false,
action: None,
feedback: None,
is_duplicate: false,
reward: None,
visits: 0,
value: 0.0,
max_expansions,
observation: None,
user_context: UserContext::default(),
message: None,
}
}

Expand All @@ -60,6 +71,21 @@ impl ActionNode {
.map(|observation| observation.terminal)
.unwrap_or_default()
}

/// Clears the state a node accumulates while it is being run, so the node can be run again.
///
/// The chosen action, its observation and feedback, the duplicate flag and the
/// reward/visit/value statistics are all returned to their initial values.
/// `index`, `max_expansions`, `user_context` and `message` are left untouched.
fn reset(&mut self) {
    // forget the action that was taken and everything derived from it
    self.action = None;
    self.observation = None;
    self.feedback = None;
    self.is_duplicate = false;

    // zero out the search statistics
    self.reward = None;
    self.value = 0.0;
    self.visits = 0;
}

/// Placeholder for converting a sequence of nodes into LLM client messages.
///
/// The conversion is not implemented yet: `nodes` is ignored and an empty
/// list of messages is always returned.
fn to_messages(nodes: Vec<&Self>) -> Vec<LLMClientMessage> {
    Vec::new()
}
}

pub struct SearchTree {
Expand All @@ -73,6 +99,7 @@ pub struct SearchTree {
/// maximum depth the nodes can go to
max_depth: u32,
selector: Selector,
tools: Vec<ToolType>,
}

impl SearchTree {
Expand Down Expand Up @@ -278,7 +305,7 @@ impl SearchTree {
&& node_children
.to_vec()
.into_iter()
.filter_map(|child| child._action.clone())
.filter_map(|child| child.action.clone())
.map(|tool_parameters| tool_parameters.to_tool_type())
.any(|tool_type| {
bad_child_actions
Expand Down Expand Up @@ -487,7 +514,7 @@ impl SearchTree {
return true;
}
let node = node.expect("if let None to hold");
node._is_duplicate
node.is_duplicate
}

/// Recursively grabs all the expandable nodes starting from the root
Expand Down Expand Up @@ -594,4 +621,60 @@ impl SearchTree {
self.add_node_to_parent(node_index, child_node_index);
Some(child_node_index)
}

/// Detaches every direct child of `node_index` from the graph.
///
/// Both directions of the edge are dropped: the children list stored for the node in
/// `node_to_children` is emptied, and each former child loses its entry in
/// `node_to_parent`. Nothing happens when `node_index` does not resolve to a node.
fn reset_children_for_node(&mut self, node_index: usize) {
    // snapshot the children before the edge maps are mutated
    let detached_children = match self.get_node(node_index) {
        Some(node) => self.children_indices(node).unwrap_or_default(),
        None => return,
    };

    // node -> children edges
    if let Some(child_edges) = self.node_to_children.get_mut(&node_index) {
        child_edges.clear();
    }

    // child -> parent edges
    for child_index in detached_children {
        self.node_to_parent.remove(&child_index);
    }
}

/// Collects the chain of nodes from `node_index` up through its ancestors.
///
/// The returned list starts with the node at `node_index`, followed by its parent,
/// then the parent's parent, and so on until a node without a parent is reached.
/// An unknown `node_index` yields an empty list.
// NOTE(review): the order is node-first, root-last — confirm that is what the consumer
// of the trajectory expects.
fn get_trajectory(&self, node_index: usize) -> Vec<&ActionNode> {
    let mut trajectory = Vec::new();
    let mut cursor = self.get_node(node_index);
    while let Some(current) = cursor {
        trajectory.push(current);
        // step to the parent, resolving it again through its index
        cursor = self
            .parent(current)
            .and_then(|parent_node| self.get_node(parent_node.index));
    }
    trajectory
}

/// Prepares the node at `node_index` for a fresh run.
///
/// The node's own state is reset, its child edges are removed from the graph, and
/// the trajectory leading to it is gathered. The remaining steps (choosing and
/// executing an action, recording the observation, computing the reward) are not
/// implemented yet. Does nothing when `node_index` does not resolve to a node.
pub fn run_node(&mut self, node_index: usize) {
    // wipe the node itself; bail out when it does not exist
    match self.get_node_mut(node_index) {
        Some(node) => node.reset(),
        None => return,
    }
    // wipe the graph below this node as well
    self.reset_children_for_node(node_index);

    // the trajectory is the input for the message we want to run inference on
    let _node_trajectory = self.get_trajectory(node_index);

    // still to do: pick the next action we want to take over here
    // - execute the action
    // - add the observation to the node
    // - generate the value reward
}
}

0 comments on commit f62abea

Please sign in to comment.