Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: step hooks #197

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion crates/llm-chain/src/frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
//! combination of types that implement the required traits.

use crate::output::Output;
use crate::prompt;
use crate::step::Step;
use crate::traits;
use crate::traits::ExecutorError;
Expand Down Expand Up @@ -45,8 +46,19 @@ where
&self,
parameters: &Parameters,
) -> Result<Output, FormatAndExecuteError> {
if let Some(before) = self.step.before.as_ref() {
if let Err(e) = before(parameters) {
panic!("Error: In before hook, {}", e);
}
}
let prompt = self.step.format(parameters)?;
Ok(self.executor.execute(self.step.options(), &prompt).await?)
let output = self.executor.execute(self.step.options(), &prompt).await?;
if let Some(after) = self.step.after.as_ref() {
if let Err(e) = after(&output) {
panic!("Error: In after hook, {}", e);
}
}
Ok(output)
}
}

Expand All @@ -58,3 +70,22 @@ pub enum FormatAndExecuteError {
#[error("Error executing: {0}")]
Execute(#[from] ExecutorError),
}

#[cfg(test)]
mod tests {
use super::*;
// Tests for step
#[test]
fn test_step() {
let mut step = Step::for_prompt_template(prompt!("Hello, world!"));
fn spy_fn(_: &Parameters) -> Result<(), String> {
Ok(())
}
step.add_before_hook(spy_fn);

fn dummy_fn_with_error(_: &Output) -> Result<(), String> {
Err("Exit with error".to_string())
}
step.add_after_hook(dummy_fn_with_error);
}
}
62 changes: 60 additions & 2 deletions crates/llm-chain/src/step.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,49 @@ use crate::{chains::sequential, prompt, Parameters};

use serde::Deserialize;
use serde::Serialize;

/// The types for before and after hooks. Parameters and Output are readonly currently.
pub type BeforeStepHook = fn(&Parameters) -> Result<(), String>;
pub type AfterStepHook = fn(&Output) -> Result<(), String>;

#[derive(derive_builder::Builder, Debug, Clone, Serialize, Deserialize)]
/// A step in a chain of LLM invocations. It is a combination of a prompt and a configuration.
pub struct Step {
pub(crate) prompt: prompt::PromptTemplate,
pub(crate) options: Options,
#[serde(skip)]
pub(crate) before: Option<BeforeStepHook>,
#[serde(skip)]
pub(crate) after: Option<AfterStepHook>,
}

impl Step {
pub fn for_prompt_template(prompt: prompt::PromptTemplate) -> Self {
Self {
prompt,
options: Options::empty().clone(),
before: None,
after: None,
}
}
pub fn for_prompt_with_streaming(prompt: prompt::PromptTemplate) -> Self {
let mut options = Options::builder();
options.add_option(Opt::Stream(true));
let options = options.build();
Self { prompt, options }
Self {
prompt,
options,
before: None,
after: None,
}
}
pub fn for_prompt_and_options(prompt: prompt::PromptTemplate, options: Options) -> Self {
Self { prompt, options }
Self {
prompt,
options,
before: None,
after: None,
}
}
pub fn prompt(&self) -> &prompt::PromptTemplate {
&self.prompt
Expand All @@ -58,6 +79,19 @@ impl Step {
self.prompt.format(parameters)
}

/// Add before and after hooks to the step.
/// Before hook will be called before the parameters are fed to the prompt template.
/// After hook will be called after the output for the step is generated.
/// # Argument
/// * before/after: the hook itself
/// # Returns
/// * Ok(()) on success and Err(String) on fail
pub fn add_before_hook(&mut self, before: BeforeStepHook) {
self.before = Some(before);
}
pub fn add_after_hook(&mut self, after: AfterStepHook) {
self.after = Some(after);
}
/// Executes the step with the given parameters and executor.
/// # Arguments
/// * `parameters` - A `Parameters` object containing the input parameters for the step.
Expand All @@ -78,3 +112,27 @@ impl Step {
.await
}
}

#[cfg(test)]
mod tests {
use super::*;
// Tests for step
#[test]
fn test_add_step_hooks() {
let mut step = Step::for_prompt_template(prompt!("Hello, world!"));
assert_eq!(step.before, None);
assert_eq!(step.after, None);

fn dummy_fn(_: &Parameters) -> Result<(), String> {
Ok(())
}
step.add_before_hook(dummy_fn);

fn dummy_fn_with_error(_: &Output) -> Result<(), String> {
Err("Exit with error".to_string())
}
step.add_after_hook(dummy_fn_with_error);
assert_ne!(step.before, None);
assert_ne!(step.after, None);
}
}
3 changes: 1 addition & 2 deletions crates/llm-chain/src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
//! By implementing these traits, you can set up a new model and use it in your application. Your step defines the input to the model, and your executor invokes the model and returns the output. The output of the executor is then passed to the next step in the chain, and so on.
//!

use std::{error::Error, fmt::Debug};

use crate::{
options::Options,
output::Output,
Expand All @@ -19,6 +17,7 @@ use crate::{
tokens::{PromptTokensError, TokenCount, Tokenizer, TokenizerError},
};
use async_trait::async_trait;
use std::{error::Error, fmt::Debug};

#[derive(thiserror::Error, Debug)]
#[error("unable to create executor")]
Expand Down
82 changes: 82 additions & 0 deletions crates/llm-chain/tests/test_step_hooks.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
use llm_chain::options::Options;
use llm_chain::output::Output;
use llm_chain::parameters;
use llm_chain::prompt;
use llm_chain::prompt::Prompt;
use llm_chain::step::Step;
use llm_chain::tokens::PromptTokensError;
use llm_chain::tokens::TokenCount;
use llm_chain::tokens::{Tokenizer, TokenizerError};
use llm_chain::traits::Executor;
use llm_chain::traits::{ExecutorCreationError, ExecutorError};
use llm_chain::Parameters;

use async_trait::async_trait;
struct MockExecutor {}
struct MockTokenizer {}

/// Mock Tokenizer implementation only for testing purposes
impl Tokenizer for MockTokenizer {
fn split_text(
&self,
_: &str,
_: usize,
_: usize,
) -> Result<Vec<String>, llm_chain::tokens::TokenizerError> {
Ok(vec!["hello,".to_string(), "world".to_string()])
}
fn to_string(
&self,
_: llm_chain::tokens::TokenCollection,
) -> Result<String, llm_chain::tokens::TokenizerError> {
Ok("hello, world".to_string())
}
fn tokenize_str(
&self,
_: &str,
) -> Result<llm_chain::tokens::TokenCollection, llm_chain::tokens::TokenizerError> {
Ok(vec![1, 2].into())
}
}

/// Mock Executor implementation only for testing purposes
#[async_trait]
impl Executor for MockExecutor {
type StepTokenizer<'a> = MockTokenizer;
fn answer_prefix(&self, _: &llm_chain::prompt::Prompt) -> Option<String> {
Some("answer".to_string())
}
async fn execute(&self, _: &Options, _: &Prompt) -> Result<Output, ExecutorError> {
Ok(Output::new_immediate("hello, world".to_string().into()))
}
fn new_with_options(_: Options) -> Result<Self, ExecutorCreationError> {
Ok(Self {})
}
fn tokens_used(&self, _: &Options, _: &Prompt) -> Result<TokenCount, PromptTokensError> {
Ok(TokenCount::new(42, 1))
}
fn max_tokens_allowed(&self, _: &Options) -> i32 {
42
}
fn get_tokenizer(&self, _: &Options) -> Result<MockTokenizer, TokenizerError> {
Ok(MockTokenizer {})
}
}

#[cfg(test)]
mod tests {
// Test for step hooks
use super::*;
#[tokio::test]
async fn test_step_hooks() {
let exec = MockExecutor {};
let mut step = Step::for_prompt_template(prompt!("Say something to {{name}}"));
fn before(p: &Parameters) -> Result<(), String> {
assert_eq!(p, &parameters!("Retep"));
assert_ne!(p, &parameters!("Mary"));
Ok(())
}
step.add_before_hook(before);
let _ = step.run(&parameters!("Retep"), &exec).await;
}
}