Skip to content

Commit

Permalink
Improvements on agents package API (tmc#551)
Browse files Browse the repository at this point in the history
* feat: change Executor to *Executor

* docs: add deprecated comment to agents.Initialize

* feat: change CreationOptions to Options in agents

* feat: change CreationOption type to Option in agents
  • Loading branch information
haochunchang authored Mar 19, 2024
1 parent dc3e6f6 commit 4746a5d
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 56 deletions.
2 changes: 1 addition & 1 deletion agents/conversational.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ type ConversationalAgent struct {

var _ Agent = (*ConversationalAgent)(nil)

func NewConversationalAgent(llm llms.Model, tools []tools.Tool, opts ...CreationOption) *ConversationalAgent {
func NewConversationalAgent(llm llms.Model, tools []tools.Tool, opts ...Option) *ConversationalAgent {
options := conversationalDefaultOptions()
for _, opt := range opts {
opt(&options)
Expand Down
24 changes: 12 additions & 12 deletions agents/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,18 @@ type Executor struct {
}

var (
_ chains.Chain = Executor{}
_ callbacks.HandlerHaver = Executor{}
_ chains.Chain = &Executor{}
_ callbacks.HandlerHaver = &Executor{}
)

// NewExecutor creates a new agent executor with an agent and the tools the agent can use.
func NewExecutor(agent Agent, tools []tools.Tool, opts ...CreationOption) Executor {
func NewExecutor(agent Agent, tools []tools.Tool, opts ...Option) *Executor {
options := executorDefaultOptions()
for _, opt := range opts {
opt(&options)
}

return Executor{
return &Executor{
Agent: agent,
Tools: tools,
Memory: options.memory,
Expand All @@ -49,7 +49,7 @@ func NewExecutor(agent Agent, tools []tools.Tool, opts ...CreationOption) Execut
}
}

func (e Executor) Call(ctx context.Context, inputValues map[string]any, _ ...chains.ChainCallOption) (map[string]any, error) { //nolint:lll
func (e *Executor) Call(ctx context.Context, inputValues map[string]any, _ ...chains.ChainCallOption) (map[string]any, error) { //nolint:lll
inputs, err := inputsToString(inputValues)
if err != nil {
return nil, err
Expand All @@ -76,7 +76,7 @@ func (e Executor) Call(ctx context.Context, inputValues map[string]any, _ ...cha
), ErrNotFinished
}

func (e Executor) doIteration( // nolint
func (e *Executor) doIteration( // nolint
ctx context.Context,
steps []schema.AgentStep,
nameToTool map[string]tools.Tool,
Expand Down Expand Up @@ -118,7 +118,7 @@ func (e Executor) doIteration( // nolint
return steps, nil, nil
}

func (e Executor) doAction(
func (e *Executor) doAction(
ctx context.Context,
steps []schema.AgentStep,
nameToTool map[string]tools.Tool,
Expand Down Expand Up @@ -147,7 +147,7 @@ func (e Executor) doAction(
}), nil
}

func (e Executor) getReturn(finish *schema.AgentFinish, steps []schema.AgentStep) map[string]any {
func (e *Executor) getReturn(finish *schema.AgentFinish, steps []schema.AgentStep) map[string]any {
if e.ReturnIntermediateSteps {
finish.ReturnValues[_intermediateStepsOutputKey] = steps
}
Expand All @@ -157,20 +157,20 @@ func (e Executor) getReturn(finish *schema.AgentFinish, steps []schema.AgentStep

// GetInputKeys gets the input keys the agent of the executor expects.
// Often "input".
func (e Executor) GetInputKeys() []string {
func (e *Executor) GetInputKeys() []string {
return e.Agent.GetInputKeys()
}

// GetOutputKeys gets the output keys the agent of the executor returns.
func (e Executor) GetOutputKeys() []string {
func (e *Executor) GetOutputKeys() []string {
return e.Agent.GetOutputKeys()
}

func (e Executor) GetMemory() schema.Memory { //nolint:ireturn
func (e *Executor) GetMemory() schema.Memory { //nolint:ireturn
return e.Memory
}

func (e Executor) GetCallbackHandler() callbacks.Handler { //nolint:ireturn
func (e *Executor) GetCallbackHandler() callbacks.Handler { //nolint:ireturn
return e.CallbacksHandler
}

Expand Down
7 changes: 4 additions & 3 deletions agents/initialize.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,24 @@ const (
ConversationalReactDescription AgentType = "conversationalReactDescription"
)

// Deprecated: This may be removed in the future; please use NewExecutor instead.
// Initialize is a function that creates a new executor with the specified LLM
// model, tools, agent type, and options. It returns an Executor or an error
// if there is any issues during the creation process.
func Initialize(
llm llms.Model,
tools []tools.Tool,
agentType AgentType,
opts ...CreationOption,
) (Executor, error) {
opts ...Option,
) (*Executor, error) {
var agent Agent
switch agentType {
case ZeroShotReactDescription:
agent = NewOneShotAgent(llm, tools, opts...)
case ConversationalReactDescription:
agent = NewConversationalAgent(llm, tools, opts...)
default:
return Executor{}, ErrUnknownAgentType
return &Executor{}, ErrUnknownAgentType
}
return NewExecutor(agent, tools, opts...), nil
}
2 changes: 1 addition & 1 deletion agents/mrkl.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ var _ Agent = (*OneShotZeroAgent)(nil)
// NewOneShotAgent creates a new OneShotZeroAgent with the given LLM model, tools,
// and options. It returns a pointer to the created agent. The opts parameter
// represents the options for the agent.
func NewOneShotAgent(llm llms.Model, tools []tools.Tool, opts ...CreationOption) *OneShotZeroAgent {
func NewOneShotAgent(llm llms.Model, tools []tools.Tool, opts ...Option) *OneShotZeroAgent {
options := mrklDefaultOptions()
for _, opt := range opts {
opt(&options)
Expand Down
4 changes: 2 additions & 2 deletions agents/openai_functions_agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ type OpenAIFunctionsAgent struct {
var _ Agent = (*OpenAIFunctionsAgent)(nil)

// NewOpenAIFunctionsAgent creates a new OpenAIFunctionsAgent.
func NewOpenAIFunctionsAgent(llm llms.Model, tools []tools.Tool, opts ...CreationOption) *OpenAIFunctionsAgent {
func NewOpenAIFunctionsAgent(llm llms.Model, tools []tools.Tool, opts ...Option) *OpenAIFunctionsAgent {
options := openAIFunctionsDefaultOptions()
for _, opt := range opts {
opt(&options)
Expand Down Expand Up @@ -132,7 +132,7 @@ func (o *OpenAIFunctionsAgent) GetOutputKeys() []string {
return []string{o.OutputKey}
}

func createOpenAIFunctionPrompt(opts CreationOptions) prompts.ChatPromptTemplate {
func createOpenAIFunctionPrompt(opts Options) prompts.ChatPromptTemplate {
messageFormatters := []prompts.MessageFormatter{prompts.NewSystemMessagePromptTemplate(opts.systemMessage, nil)}
messageFormatters = append(messageFormatters, opts.extraMessages...)
messageFormatters = append(messageFormatters, prompts.NewHumanMessagePromptTemplate("{{.input}}", []string{"input"}))
Expand Down
74 changes: 37 additions & 37 deletions agents/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"github.com/tmc/langchaingo/tools"
)

type CreationOptions struct {
type Options struct {
prompt prompts.PromptTemplate
memory schema.Memory
callbacksHandler callbacks.Handler
Expand All @@ -25,44 +25,44 @@ type CreationOptions struct {
extraMessages []prompts.MessageFormatter
}

// CreationOption is a function type that can be used to modify the creation of the agents
// Option is a function type that can be used to modify the creation of the agents
// and executors.
type CreationOption func(*CreationOptions)
type Option func(*Options)

func executorDefaultOptions() CreationOptions {
return CreationOptions{
func executorDefaultOptions() Options {
return Options{
maxIterations: _defaultMaxIterations,
outputKey: _defaultOutputKey,
memory: memory.NewSimple(),
}
}

func mrklDefaultOptions() CreationOptions {
return CreationOptions{
func mrklDefaultOptions() Options {
return Options{
promptPrefix: _defaultMrklPrefix,
formatInstructions: _defaultMrklFormatInstructions,
promptSuffix: _defaultMrklSuffix,
outputKey: _defaultOutputKey,
}
}

func conversationalDefaultOptions() CreationOptions {
return CreationOptions{
func conversationalDefaultOptions() Options {
return Options{
promptPrefix: _defaultConversationalPrefix,
formatInstructions: _defaultConversationalFormatInstructions,
promptSuffix: _defaultConversationalSuffix,
outputKey: _defaultOutputKey,
}
}

func openAIFunctionsDefaultOptions() CreationOptions {
return CreationOptions{
func openAIFunctionsDefaultOptions() Options {
return Options{
systemMessage: "You are a helpful AI assistant.",
outputKey: _defaultOutputKey,
}
}

func (co CreationOptions) getMrklPrompt(tools []tools.Tool) prompts.PromptTemplate {
func (co Options) getMrklPrompt(tools []tools.Tool) prompts.PromptTemplate {
if co.prompt.Template != "" {
return co.prompt
}
Expand All @@ -75,7 +75,7 @@ func (co CreationOptions) getMrklPrompt(tools []tools.Tool) prompts.PromptTempla
)
}

func (co CreationOptions) getConversationalPrompt(tools []tools.Tool) prompts.PromptTemplate {
func (co Options) getConversationalPrompt(tools []tools.Tool) prompts.PromptTemplate {
if co.prompt.Template != "" {
return co.prompt
}
Expand All @@ -90,73 +90,73 @@ func (co CreationOptions) getConversationalPrompt(tools []tools.Tool) prompts.Pr

// WithMaxIterations is an option for setting the max number of iterations the executor
// will complete.
func WithMaxIterations(iterations int) CreationOption {
return func(co *CreationOptions) {
func WithMaxIterations(iterations int) Option {
return func(co *Options) {
co.maxIterations = iterations
}
}

// WithOutputKey is an option for setting the output key of the agent.
func WithOutputKey(outputKey string) CreationOption {
return func(co *CreationOptions) {
func WithOutputKey(outputKey string) Option {
return func(co *Options) {
co.outputKey = outputKey
}
}

// WithPromptPrefix is an option for setting the prefix of the prompt used by the agent.
func WithPromptPrefix(prefix string) CreationOption {
return func(co *CreationOptions) {
func WithPromptPrefix(prefix string) Option {
return func(co *Options) {
co.promptPrefix = prefix
}
}

// WithPromptFormatInstructions is an option for setting the format instructions of the prompt
// used by the agent.
func WithPromptFormatInstructions(instructions string) CreationOption {
return func(co *CreationOptions) {
func WithPromptFormatInstructions(instructions string) Option {
return func(co *Options) {
co.formatInstructions = instructions
}
}

// WithPromptSuffix is an option for setting the suffix of the prompt used by the agent.
func WithPromptSuffix(suffix string) CreationOption {
return func(co *CreationOptions) {
func WithPromptSuffix(suffix string) Option {
return func(co *Options) {
co.promptSuffix = suffix
}
}

// WithPrompt is an option for setting the prompt the agent will use.
func WithPrompt(prompt prompts.PromptTemplate) CreationOption {
return func(co *CreationOptions) {
func WithPrompt(prompt prompts.PromptTemplate) Option {
return func(co *Options) {
co.prompt = prompt
}
}

// WithReturnIntermediateSteps is an option for making the executor return the intermediate steps
// taken.
func WithReturnIntermediateSteps() CreationOption {
return func(co *CreationOptions) {
func WithReturnIntermediateSteps() Option {
return func(co *Options) {
co.returnIntermediateSteps = true
}
}

// WithMemory is an option for setting the memory of the executor.
func WithMemory(m schema.Memory) CreationOption {
return func(co *CreationOptions) {
func WithMemory(m schema.Memory) Option {
return func(co *Options) {
co.memory = m
}
}

// WithCallbacksHandler is an option for setting a callback handler to an executor.
func WithCallbacksHandler(handler callbacks.Handler) CreationOption {
return func(co *CreationOptions) {
func WithCallbacksHandler(handler callbacks.Handler) Option {
return func(co *Options) {
co.callbacksHandler = handler
}
}

// WithParserErrorHandler is an option for setting a parser error handler to an executor.
func WithParserErrorHandler(errorHandler *ParserErrorHandler) CreationOption {
return func(co *CreationOptions) {
func WithParserErrorHandler(errorHandler *ParserErrorHandler) Option {
return func(co *Options) {
co.errorHandler = errorHandler
}
}
Expand All @@ -167,14 +167,14 @@ func NewOpenAIOption() OpenAIOption {
return OpenAIOption{}
}

func (o OpenAIOption) WithSystemMessage(msg string) CreationOption {
return func(co *CreationOptions) {
func (o OpenAIOption) WithSystemMessage(msg string) Option {
return func(co *Options) {
co.systemMessage = msg
}
}

func (o OpenAIOption) WithExtraMessages(extraMessages []prompts.MessageFormatter) CreationOption {
return func(co *CreationOptions) {
func (o OpenAIOption) WithExtraMessages(extraMessages []prompts.MessageFormatter) Option {
return func(co *Options) {
co.extraMessages = extraMessages
}
}

0 comments on commit 4746a5d

Please sign in to comment.