Skip to content

Commit

Permalink
Aigentic 27 improve and test agentexecutor (#10)
Browse files Browse the repository at this point in the history
* Added basic test for AgentExecutor

* Added more tests

* Fix test execution

* ToolCall and ToolResult tests

* Refactored AgentExecutor.kt

* Re-enable logging

* Spotless

* Fix example

* Added separate state object

* Spotless

* Cleanup

---------

Co-authored-by: Niels Simonides <[email protected]>
  • Loading branch information
nsmnds and Niels Simonides authored May 15, 2024
1 parent 5f3c9af commit 26aff6f
Show file tree
Hide file tree
Showing 26 changed files with 631 additions and 273 deletions.
2 changes: 1 addition & 1 deletion .editorconfig
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ indent_size = 4
[*.{kt,kts}]
ktlint_code_style = ktlint_official
ktlint_ignore_back_ticked_identifier = true

ktlint_standard = enabled
ktlint_standard_filename = disabled
ktlint_standard_property-naming = disabled

# Experimental rules run by default run on the ktlint code base itself. Experimental rules should not be released if
Expand Down
2 changes: 0 additions & 2 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@ jobs:
include:
- target: jvmTest
os: ubuntu-latest
- target: linuxX64Test
os: ubuntu-latest
- target: jsTest
os: ubuntu-latest
- target: jsNodeTest
Expand Down
1 change: 1 addition & 0 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ plugins {
}

subprojects {

apply(plugin = "com.diffplug.spotless")
configure<SpotlessExtension> {
kotlin {
Expand Down
39 changes: 35 additions & 4 deletions src/core/build.gradle.kts
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
plugins {
alias(libs.plugins.kotlinMultiplatform)
id("module.publication")
alias(libs.plugins.kotlinx.serialization)
alias(libs.plugins.dokka)
kotlin("plugin.serialization")
alias(libs.plugins.kotest.multiplatform)
id("module.publication")
}

kotlin {
jvm()
linuxX64()
js(IR) {
nodejs()
generateTypeScriptDefinitions()
}

sourceSets {

all {
languageSettings.optIn("kotlinx.coroutines.ExperimentalCoroutinesApi")
}

val commonMain by getting {
dependencies {
api(libs.coroutines.core)
Expand All @@ -24,8 +29,34 @@ kotlin {
}
val commonTest by getting {
dependencies {
implementation(libs.kotlin.test)
implementation(libs.kotest.framework.engine)
implementation(libs.kotest.assertions.core)
implementation(libs.kotest.framework.datatest)
implementation(libs.kotest.property)
}
}
val jvmTest by getting {
dependencies {
implementation(libs.kotest.runner.junit5)
implementation(libs.kotlin.reflect)
implementation(libs.mockk)
}
}
}
}

tasks.named<Test>("jvmTest") {
useJUnitPlatform()
filter {
isFailOnNoMatchingTests = false
}
testLogging {
showExceptions = true
showStandardStreams = true
events = setOf(
org.gradle.api.tasks.testing.logging.TestLogEvent.FAILED,
org.gradle.api.tasks.testing.logging.TestLogEvent.PASSED
)
exceptionFormat = org.gradle.api.tasks.testing.logging.TestExceptionFormat.FULL
}
}
Original file line number Diff line number Diff line change
@@ -1,17 +1,11 @@
package community.flock.aigentic.core.agent

import community.flock.aigentic.core.agent.prompt.SystemPromptBuilder
import community.flock.aigentic.core.message.Message
import community.flock.aigentic.core.agent.message.SystemPromptBuilder
import community.flock.aigentic.core.agent.tool.finishOrStuckTool
import community.flock.aigentic.core.model.Model
import community.flock.aigentic.core.tool.InternalTool
import community.flock.aigentic.core.tool.Tool
import community.flock.aigentic.core.tool.ToolName
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.asSharedFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.datetime.Clock
import kotlinx.datetime.Instant

data class Task(
val description: String,
Expand All @@ -26,21 +20,6 @@ sealed interface Context {
data class Image(val base64: String) : Context
}

enum class AgentRunningState(val value: String) {
WAITING_TO_START("WAITING_TO_START"),
RUNNING("RUNNING"),
EXECUTING_TOOL("EXECUTING_TOOL"),
WAITING_ON_APPROVAL("WAITING_ON_APPROVAL"),
COMPLETED("COMPLETED"),
STUCK("STUCK"),
}

data class AgentStatus(
var runningState: AgentRunningState = AgentRunningState.WAITING_TO_START,
val startTimestamp: Instant = Clock.System.now(),
var endTimestamp: Instant? = null,
)

data class Agent(
val id: String,
val systemPromptBuilder: SystemPromptBuilder,
Expand All @@ -49,11 +28,5 @@ data class Agent(
val contexts: List<Context>,
val tools: Map<ToolName, Tool>,
) {
internal val messages = MutableSharedFlow<Message>(replay = 100)
internal val status = MutableStateFlow(AgentStatus())
internal val internalTools = mutableMapOf<ToolName, InternalTool<*>>()
internal val internalTools: Map<ToolName, InternalTool<*>> = mapOf(finishOrStuckTool.name to finishOrStuckTool)
}

fun Agent.getMessages() = messages.asSharedFlow()

fun Agent.getStatus() = status.asStateFlow()
Original file line number Diff line number Diff line change
@@ -1,169 +1,103 @@
package community.flock.aigentic.core.agent

import community.flock.aigentic.core.agent.events.toEvents
import community.flock.aigentic.core.agent.tool.FinishReason
import community.flock.aigentic.core.agent.Action.ExecuteTools
import community.flock.aigentic.core.agent.Action.Finished
import community.flock.aigentic.core.agent.Action.Initialize
import community.flock.aigentic.core.agent.Action.ProcessModelResponse
import community.flock.aigentic.core.agent.Action.SendModelRequest
import community.flock.aigentic.core.agent.message.correctionMessage
import community.flock.aigentic.core.agent.state.State
import community.flock.aigentic.core.agent.state.addMessage
import community.flock.aigentic.core.agent.state.addMessages
import community.flock.aigentic.core.agent.state.getStatus
import community.flock.aigentic.core.agent.state.toRun
import community.flock.aigentic.core.agent.tool.FinishedOrStuck
import community.flock.aigentic.core.agent.tool.finishOrStuckTool
import community.flock.aigentic.core.message.Message
import community.flock.aigentic.core.message.Sender
import community.flock.aigentic.core.message.Sender.Aigentic
import community.flock.aigentic.core.message.ToolCall
import community.flock.aigentic.core.message.ToolResultContent
import community.flock.aigentic.core.message.argumentsAsJson
import community.flock.aigentic.core.model.ModelResponse
import community.flock.aigentic.core.tool.Tool
import community.flock.aigentic.core.tool.ToolName
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.async
import kotlinx.coroutines.cancelAndJoin
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.flow.asFlow
import kotlinx.coroutines.flow.flatMapConcat
import kotlinx.datetime.Clock

data class ToolInterceptorResult(val cancelExecution: Boolean, val reason: String?)

interface ToolInterceptor {
suspend fun intercept(
agent: Agent,
tool: Tool,
toolCall: ToolCall,
): ToolInterceptorResult
}
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.map

suspend fun Agent.run(): FinishedOrStuck =
suspend fun Agent.start(): Run =
coroutineScope {
async {
getMessages().flatMapConcat { it.toEvents().asFlow() }.collect {
println(it.text)
}
}
val state = State()
val logging = async { state.getStatus().map { it.text }.collect(::println) }

AgentExecutor().runAgent(this@run)
executeAction(Initialize(state, this@start)).also {
delay(10) // Allow some time for the logging to finish
logging.cancelAndJoin()
}.toRun()
}

class AgentExecutor(private val toolInterceptors: List<ToolInterceptor> = emptyList()) {
suspend fun runAgent(agent: Agent): FinishedOrStuck {
agent.setRunningState(AgentRunningState.RUNNING)

agent.initialize() // Maybe move to Agent builder?
val modelResponse = agent.sendModelRequest()

val result = CompletableDeferred<FinishedOrStuck>()
processResponse(agent, modelResponse) { result.complete(it) }
val resultState = result.await()

agent.updateStatus {
val endRunningState =
if (resultState.reason is FinishReason.ImStuck) {
AgentRunningState.STUCK
} else {
AgentRunningState.COMPLETED
}
it.copy(
runningState = endRunningState,
endTimestamp = Clock.System.now(),
)
}
return resultState
private suspend fun executeAction(action: Action): Pair<State, FinishedOrStuck> =
when (action) {
is Initialize -> executeAction(action.process(action.state))
is SendModelRequest -> executeAction(action.process(action.state))
is ProcessModelResponse -> executeAction(action.process(action.state))
is ExecuteTools -> executeAction(action.process(action.state))
is Finished -> action.process(action.state)
}

private suspend fun Agent.initialize() {
internalTools[finishOrStuckTool.name] = finishOrStuckTool
messages.emit(systemPromptBuilder.buildSystemPrompt(this))
contexts.map {
when (it) {
is Context.Image -> Message.Image(Sender.Aigentic, it.base64)
is Context.Text -> Message.Text(Sender.Aigentic, it.text)
}
}.forEach { messages.emit(it) }
}

private suspend fun processResponse(
agent: Agent,
response: ModelResponse,
onFinished: (FinishedOrStuck) -> Unit,
) {
val message = response.message
agent.messages.emit(message)

when (message) {
is Message.ToolCalls -> {
val shouldSendNextRequest =
message.toolCalls
.map { toolCall ->
when (toolCall.name) {
finishOrStuckTool.name.value -> {
val finishedOrStuck = finishOrStuckTool.handler(toolCall.argumentsAsJson())
onFinished(finishedOrStuck)
false
}

else -> {
val toolResult = agent.execute(toolCall)
agent.messages.emit(toolResult)
true
}
}
}
.contains(true)

if (shouldSendNextRequest) {
sendToolResponse(agent, onFinished)
}
}
private suspend fun Initialize.process(state: State): Action {
state.addMessages(initializeStartMessages(agent))
return SendModelRequest(state, agent)
}

else -> error("Expected ToolCalls message, got $message")
private suspend fun ProcessModelResponse.process(state: State): Action =
when (responseMessage) {
is Message.ToolCalls -> ExecuteTools(state, agent, responseMessage.toolCalls)
else -> {
state.messages.emit(correctionMessage)
SendModelRequest(state, agent)
}
}

private suspend fun Agent.execute(toolCall: ToolCall): Message.ToolResult {
val functionArgs = toolCall.argumentsAsJson()
val tool = tools[ToolName(toolCall.name)] ?: error("Tool not registered: $toolCall")
private suspend fun SendModelRequest.process(state: State): ProcessModelResponse {
val message = agent.sendModelRequest(state).message
state.addMessage(message)
return ProcessModelResponse(state, agent, message)
}

private fun Finished.process(state: State) = state to finishedOrStuck

val cancelMessage = runInterceptors(this, tool, toolCall)
if (cancelMessage != null) {
setRunningState(AgentRunningState.RUNNING)
return cancelMessage
}
private suspend fun ExecuteTools.process(state: State): Action {
val toolResults = executeToolCalls(agent, toolCalls)
val finished = toolResults.filterIsInstance<ToolExecutionResult.FinishedToolResult>().firstOrNull()

setRunningState(AgentRunningState.EXECUTING_TOOL)
val result = tool.handler(functionArgs)
setRunningState(AgentRunningState.RUNNING)
return Message.ToolResult(toolCall.id, toolCall.name, ToolResultContent(result))
return if (finished != null) {
Finished(state, agent, finished.reason)
} else {
state.addMessages(toolResults.filterIsInstance<ToolExecutionResult.ToolResult>().map { it.message })
SendModelRequest(state, agent)
}
}

private suspend fun runInterceptors(
agent: Agent,
tool: Tool,
toolCall: ToolCall,
): Message.ToolResult? =
toolInterceptors
.map { it.intercept(agent, tool, toolCall) }
.firstOrNull { it.cancelExecution }?.let {
Message.ToolResult(
toolCall.id,
toolCall.name,
ToolResultContent(it.reason ?: "Tool execution blocked by interceptor"),
)
private fun initializeStartMessages(agent: Agent): List<Message> =
listOf(
agent.systemPromptBuilder.buildSystemPrompt(agent),
) +
agent.contexts.map {
when (it) {
is Context.Image -> Message.Image(Aigentic, it.base64)
is Context.Text -> Message.Text(Aigentic, it.text)
}
}

private suspend fun sendToolResponse(
agent: Agent,
onFinished: (FinishedOrStuck) -> Unit,
) {
val response = agent.sendModelRequest()
processResponse(agent, response, onFinished)
}
private suspend fun Agent.sendModelRequest(state: State): ModelResponse =
model.sendRequest(state.messages.replayCache, tools.values.toList() + internalTools.values.toList())

private suspend fun Agent.sendModelRequest(): ModelResponse =
model.sendRequest(messages.replayCache, tools.values.toList() + internalTools.values.toList())
}
private sealed interface Action {
data class Initialize(val state: State, val agent: Agent) : Action

suspend fun Agent.updateStatus(update: (currentStatus: AgentStatus) -> AgentStatus) {
update.invoke(status.value).let {
status.emit(it)
}
}
data class ExecuteTools(val state: State, val agent: Agent, val toolCalls: List<ToolCall>) : Action

data class SendModelRequest(val state: State, val agent: Agent) : Action

data class ProcessModelResponse(val state: State, val agent: Agent, val responseMessage: Message) : Action

suspend fun Agent.setRunningState(state: AgentRunningState) {
updateStatus { status.value.copy(runningState = state) }
data class Finished(val state: State, val agent: Agent, val finishedOrStuck: FinishedOrStuck) : Action
}
Loading

0 comments on commit 26aff6f

Please sign in to comment.