-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Aigentic 27 improve and test agentexecutor (#10)
* Added basic test for AgentExecutor * Added more tests * Fix test execution * ToolCall and ToolResult tests * Refactored AgentExecutor.kt * Re-enable logging * Spotless * Fix example * Added separate state object * Spotless * Cleanup --------- Co-authored-by: Niels Simonides <[email protected]>
- Loading branch information
Showing
26 changed files
with
631 additions
and
273 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
214 changes: 74 additions & 140 deletions
214
src/core/src/commonMain/kotlin/community/flock/aigentic/core/agent/AgentExecutor.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,169 +1,103 @@ | ||
package community.flock.aigentic.core.agent | ||
|
||
import community.flock.aigentic.core.agent.events.toEvents | ||
import community.flock.aigentic.core.agent.tool.FinishReason | ||
import community.flock.aigentic.core.agent.Action.ExecuteTools | ||
import community.flock.aigentic.core.agent.Action.Finished | ||
import community.flock.aigentic.core.agent.Action.Initialize | ||
import community.flock.aigentic.core.agent.Action.ProcessModelResponse | ||
import community.flock.aigentic.core.agent.Action.SendModelRequest | ||
import community.flock.aigentic.core.agent.message.correctionMessage | ||
import community.flock.aigentic.core.agent.state.State | ||
import community.flock.aigentic.core.agent.state.addMessage | ||
import community.flock.aigentic.core.agent.state.addMessages | ||
import community.flock.aigentic.core.agent.state.getStatus | ||
import community.flock.aigentic.core.agent.state.toRun | ||
import community.flock.aigentic.core.agent.tool.FinishedOrStuck | ||
import community.flock.aigentic.core.agent.tool.finishOrStuckTool | ||
import community.flock.aigentic.core.message.Message | ||
import community.flock.aigentic.core.message.Sender | ||
import community.flock.aigentic.core.message.Sender.Aigentic | ||
import community.flock.aigentic.core.message.ToolCall | ||
import community.flock.aigentic.core.message.ToolResultContent | ||
import community.flock.aigentic.core.message.argumentsAsJson | ||
import community.flock.aigentic.core.model.ModelResponse | ||
import community.flock.aigentic.core.tool.Tool | ||
import community.flock.aigentic.core.tool.ToolName | ||
import kotlinx.coroutines.CompletableDeferred | ||
import kotlinx.coroutines.async | ||
import kotlinx.coroutines.cancelAndJoin | ||
import kotlinx.coroutines.coroutineScope | ||
import kotlinx.coroutines.flow.asFlow | ||
import kotlinx.coroutines.flow.flatMapConcat | ||
import kotlinx.datetime.Clock | ||
|
||
data class ToolInterceptorResult(val cancelExecution: Boolean, val reason: String?) | ||
|
||
interface ToolInterceptor { | ||
suspend fun intercept( | ||
agent: Agent, | ||
tool: Tool, | ||
toolCall: ToolCall, | ||
): ToolInterceptorResult | ||
} | ||
import kotlinx.coroutines.delay | ||
import kotlinx.coroutines.flow.map | ||
|
||
suspend fun Agent.run(): FinishedOrStuck = | ||
suspend fun Agent.start(): Run = | ||
coroutineScope { | ||
async { | ||
getMessages().flatMapConcat { it.toEvents().asFlow() }.collect { | ||
println(it.text) | ||
} | ||
} | ||
val state = State() | ||
val logging = async { state.getStatus().map { it.text }.collect(::println) } | ||
|
||
AgentExecutor().runAgent(this@run) | ||
executeAction(Initialize(state, this@start)).also { | ||
delay(10) // Allow some time for the logging to finish | ||
logging.cancelAndJoin() | ||
}.toRun() | ||
} | ||
|
||
class AgentExecutor(private val toolInterceptors: List<ToolInterceptor> = emptyList()) { | ||
suspend fun runAgent(agent: Agent): FinishedOrStuck { | ||
agent.setRunningState(AgentRunningState.RUNNING) | ||
|
||
agent.initialize() // Maybe move to Agent builder? | ||
val modelResponse = agent.sendModelRequest() | ||
|
||
val result = CompletableDeferred<FinishedOrStuck>() | ||
processResponse(agent, modelResponse) { result.complete(it) } | ||
val resultState = result.await() | ||
|
||
agent.updateStatus { | ||
val endRunningState = | ||
if (resultState.reason is FinishReason.ImStuck) { | ||
AgentRunningState.STUCK | ||
} else { | ||
AgentRunningState.COMPLETED | ||
} | ||
it.copy( | ||
runningState = endRunningState, | ||
endTimestamp = Clock.System.now(), | ||
) | ||
} | ||
return resultState | ||
private suspend fun executeAction(action: Action): Pair<State, FinishedOrStuck> = | ||
when (action) { | ||
is Initialize -> executeAction(action.process(action.state)) | ||
is SendModelRequest -> executeAction(action.process(action.state)) | ||
is ProcessModelResponse -> executeAction(action.process(action.state)) | ||
is ExecuteTools -> executeAction(action.process(action.state)) | ||
is Finished -> action.process(action.state) | ||
} | ||
|
||
private suspend fun Agent.initialize() { | ||
internalTools[finishOrStuckTool.name] = finishOrStuckTool | ||
messages.emit(systemPromptBuilder.buildSystemPrompt(this)) | ||
contexts.map { | ||
when (it) { | ||
is Context.Image -> Message.Image(Sender.Aigentic, it.base64) | ||
is Context.Text -> Message.Text(Sender.Aigentic, it.text) | ||
} | ||
}.forEach { messages.emit(it) } | ||
} | ||
|
||
private suspend fun processResponse( | ||
agent: Agent, | ||
response: ModelResponse, | ||
onFinished: (FinishedOrStuck) -> Unit, | ||
) { | ||
val message = response.message | ||
agent.messages.emit(message) | ||
|
||
when (message) { | ||
is Message.ToolCalls -> { | ||
val shouldSendNextRequest = | ||
message.toolCalls | ||
.map { toolCall -> | ||
when (toolCall.name) { | ||
finishOrStuckTool.name.value -> { | ||
val finishedOrStuck = finishOrStuckTool.handler(toolCall.argumentsAsJson()) | ||
onFinished(finishedOrStuck) | ||
false | ||
} | ||
|
||
else -> { | ||
val toolResult = agent.execute(toolCall) | ||
agent.messages.emit(toolResult) | ||
true | ||
} | ||
} | ||
} | ||
.contains(true) | ||
|
||
if (shouldSendNextRequest) { | ||
sendToolResponse(agent, onFinished) | ||
} | ||
} | ||
private suspend fun Initialize.process(state: State): Action { | ||
state.addMessages(initializeStartMessages(agent)) | ||
return SendModelRequest(state, agent) | ||
} | ||
|
||
else -> error("Expected ToolCalls message, got $message") | ||
private suspend fun ProcessModelResponse.process(state: State): Action = | ||
when (responseMessage) { | ||
is Message.ToolCalls -> ExecuteTools(state, agent, responseMessage.toolCalls) | ||
else -> { | ||
state.messages.emit(correctionMessage) | ||
SendModelRequest(state, agent) | ||
} | ||
} | ||
|
||
private suspend fun Agent.execute(toolCall: ToolCall): Message.ToolResult { | ||
val functionArgs = toolCall.argumentsAsJson() | ||
val tool = tools[ToolName(toolCall.name)] ?: error("Tool not registered: $toolCall") | ||
private suspend fun SendModelRequest.process(state: State): ProcessModelResponse { | ||
val message = agent.sendModelRequest(state).message | ||
state.addMessage(message) | ||
return ProcessModelResponse(state, agent, message) | ||
} | ||
|
||
private fun Finished.process(state: State) = state to finishedOrStuck | ||
|
||
val cancelMessage = runInterceptors(this, tool, toolCall) | ||
if (cancelMessage != null) { | ||
setRunningState(AgentRunningState.RUNNING) | ||
return cancelMessage | ||
} | ||
private suspend fun ExecuteTools.process(state: State): Action { | ||
val toolResults = executeToolCalls(agent, toolCalls) | ||
val finished = toolResults.filterIsInstance<ToolExecutionResult.FinishedToolResult>().firstOrNull() | ||
|
||
setRunningState(AgentRunningState.EXECUTING_TOOL) | ||
val result = tool.handler(functionArgs) | ||
setRunningState(AgentRunningState.RUNNING) | ||
return Message.ToolResult(toolCall.id, toolCall.name, ToolResultContent(result)) | ||
return if (finished != null) { | ||
Finished(state, agent, finished.reason) | ||
} else { | ||
state.addMessages(toolResults.filterIsInstance<ToolExecutionResult.ToolResult>().map { it.message }) | ||
SendModelRequest(state, agent) | ||
} | ||
} | ||
|
||
private suspend fun runInterceptors( | ||
agent: Agent, | ||
tool: Tool, | ||
toolCall: ToolCall, | ||
): Message.ToolResult? = | ||
toolInterceptors | ||
.map { it.intercept(agent, tool, toolCall) } | ||
.firstOrNull { it.cancelExecution }?.let { | ||
Message.ToolResult( | ||
toolCall.id, | ||
toolCall.name, | ||
ToolResultContent(it.reason ?: "Tool execution blocked by interceptor"), | ||
) | ||
private fun initializeStartMessages(agent: Agent): List<Message> = | ||
listOf( | ||
agent.systemPromptBuilder.buildSystemPrompt(agent), | ||
) + | ||
agent.contexts.map { | ||
when (it) { | ||
is Context.Image -> Message.Image(Aigentic, it.base64) | ||
is Context.Text -> Message.Text(Aigentic, it.text) | ||
} | ||
} | ||
|
||
private suspend fun sendToolResponse( | ||
agent: Agent, | ||
onFinished: (FinishedOrStuck) -> Unit, | ||
) { | ||
val response = agent.sendModelRequest() | ||
processResponse(agent, response, onFinished) | ||
} | ||
private suspend fun Agent.sendModelRequest(state: State): ModelResponse = | ||
model.sendRequest(state.messages.replayCache, tools.values.toList() + internalTools.values.toList()) | ||
|
||
private suspend fun Agent.sendModelRequest(): ModelResponse = | ||
model.sendRequest(messages.replayCache, tools.values.toList() + internalTools.values.toList()) | ||
} | ||
private sealed interface Action { | ||
data class Initialize(val state: State, val agent: Agent) : Action | ||
|
||
suspend fun Agent.updateStatus(update: (currentStatus: AgentStatus) -> AgentStatus) { | ||
update.invoke(status.value).let { | ||
status.emit(it) | ||
} | ||
} | ||
data class ExecuteTools(val state: State, val agent: Agent, val toolCalls: List<ToolCall>) : Action | ||
|
||
data class SendModelRequest(val state: State, val agent: Agent) : Action | ||
|
||
data class ProcessModelResponse(val state: State, val agent: Agent, val responseMessage: Message) : Action | ||
|
||
suspend fun Agent.setRunningState(state: AgentRunningState) { | ||
updateStatus { status.value.copy(runningState = state) } | ||
data class Finished(val state: State, val agent: Agent, val finishedOrStuck: FinishedOrStuck) : Action | ||
} |
Oops, something went wrong.