Skip to content

Commit

Permalink
PDF support for gemini (#23)
Browse files Browse the repository at this point in the history
* PDF support for gemini

* Formatting

* Updated PDF example
  • Loading branch information
nsmnds authored Aug 9, 2024
1 parent 433ccb0 commit fdddcbd
Show file tree
Hide file tree
Showing 22 changed files with 151 additions and 47 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -199,4 +199,4 @@ Aigentic is released under the MIT License.

## Support

For questions, issues, or feature requests, please open an issue on our GitHub repository. Or contact us as [info@flock.community](mailto:info@flock.community?subject=Aigentic)
For questions, issues, or feature requests, please open an issue on our GitHub repository. Or contact us at [info@aigentic.io](mailto:info@aigentic.io?subject=Aigentic)
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ data class Instruction(val text: String)
sealed interface Context {
data class Text(val text: String) : Context

data class ImageUrl(val url: String, val mimeType: MimeType) : Context
data class Url(val url: String, val mimeType: MimeType) : Context

data class ImageBase64(val base64: String, val mimeType: MimeType) : Context
data class Base64(val base64: String, val mimeType: MimeType) : Context
}

data class Agent(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,8 @@ private fun initializeStartMessages(agent: Agent): List<Message> =
listOf(agent.systemPromptBuilder.buildSystemPrompt(agent)) +
agent.contexts.map {
when (it) {
is Context.ImageUrl -> Message.ImageUrl(sender = Sender.Agent, url = it.url, mimeType = it.mimeType)
is Context.ImageBase64 -> Message.ImageBase64(sender = Sender.Agent, base64Content = it.base64, mimeType = it.mimeType)
is Context.Url -> Message.Url(sender = Sender.Agent, url = it.url, mimeType = it.mimeType)
is Context.Base64 -> Message.Base64(sender = Sender.Agent, base64Content = it.base64, mimeType = it.mimeType)
is Context.Text -> Message.Text(Sender.Agent, it.text)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ sealed interface AgentStatus {
fun Message.toStatus(): List<AgentStatus> =
when (this) {
is Message.SystemPrompt -> emptyList()
is Message.Text, is Message.ImageUrl, is Message.ImageBase64 -> emptyList()
is Message.Text, is Message.Url, is Message.Base64 -> emptyList()
is Message.ToolCalls ->
this.toolCalls.map {
when (it.name) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,16 +88,16 @@ class ContextConfig : Config<List<Context>> {
Context.Text(text)
.also { contexts.add(it) }

fun ContextConfig.addImageUrl(
fun ContextConfig.addUrl(
url: String,
mimeType: MimeType,
) = Context.ImageUrl(url = url, mimeType = mimeType)
) = Context.Url(url = url, mimeType = mimeType)
.also { contexts.add(it) }

fun ContextConfig.addImageBase64(
fun ContextConfig.addBase64(
base64: String,
mimeType: MimeType,
) = Context.ImageBase64(base64 = base64, mimeType = mimeType)
) = Context.Base64(base64 = base64, mimeType = mimeType)
.also { contexts.add(it) }

override fun build(): List<Context> = contexts
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ sealed class Message(
val text: String,
) : Message(sender)

data class ImageUrl(
data class Url(
override val sender: Sender,
val url: String,
val mimeType: MimeType,
) : Message(Sender.Agent)

data class ImageBase64(
data class Base64(
override val sender: Sender,
val base64Content: String,
val mimeType: MimeType,
Expand Down Expand Up @@ -54,6 +54,7 @@ enum class MimeType(val value: String) {
WEBP("image/webp"),
HEIC("image/heic"),
HEIF("image/heif"),
PDF("application/pdf"),
}

@JvmInline
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.decodeFromJsonElement
import kotlinx.serialization.json.int
import kotlinx.serialization.json.jsonArray
import kotlinx.serialization.json.jsonObject
import kotlinx.serialization.json.jsonPrimitive
import kotlin.jvm.JvmInline
Expand All @@ -17,6 +18,11 @@ inline fun <reified T : Any> Parameter.Complex.Object.getObject(arguments: JsonO
return Json.decodeFromJsonElement(arg.jsonObject)
}

/**
 * Decodes the items of this array parameter from the tool call [arguments].
 *
 * The entry stored under this parameter's [name] is read as a JSON array and every element
 * is decoded into a [T].
 *
 * @param arguments the JSON object holding the tool call arguments, keyed by parameter name.
 * @return the decoded items, in the order they appear in the JSON array.
 */
inline fun <reified T : Any> Parameter.Complex.Array.getItems(arguments: JsonObject): List<T> {
    val elements = arguments.getValue(name).jsonArray
    return elements.map { element -> Json.decodeFromJsonElement<T>(element) }
}

sealed class Parameter(
open val name: String,
open val description: String?,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ class AgentExecutorTest : DescribeSpec({
task("Execute some task") {}
context {
addText(expectedTextContext)
addImageBase64(base64 = expectedImageContextBase64, mimeType = expectedImageContextMimeType)
addBase64(base64 = expectedImageContextBase64, mimeType = expectedImageContextMimeType)
}
addTool(mockk(relaxed = true))
}
Expand All @@ -159,7 +159,7 @@ class AgentExecutorTest : DescribeSpec({
messages.drop(1).take(2) shouldBe
listOf(
Message.Text(Sender.Agent, expectedTextContext),
Message.ImageBase64(
Message.Base64(
sender = Sender.Agent,
base64Content = expectedImageContextBase64,
mimeType = expectedImageContextMimeType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,13 @@ class AgentConfigTest : DescribeSpec({
task("Task description") {}
context {
addText("Some text")
addImageUrl("https://example.com/image.jpg", MimeType.JPEG)
addUrl("https://example.com/image.jpg", MimeType.JPEG)
}
addTool(mockk(relaxed = true))
}.run {
contexts.size shouldBe 2
contexts.first() shouldBe Context.Text("Some text")
contexts.last() shouldBe Context.ImageUrl("https://example.com/image.jpg", MimeType.JPEG)
contexts.last() shouldBe Context.Url("https://example.com/image.jpg", MimeType.JPEG)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ suspend fun runItemCategorizeExample(
task("Identify all items in the image and save each individual item") {}
addTool(saveItemTool)
context {
addImageBase64(base64Image, MimeType.JPEG)
addBase64(base64Image, MimeType.JPEG)
}
}.start()

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package community.flock.aigentic.example

import community.flock.aigentic.core.agent.Run
import community.flock.aigentic.core.agent.start
import community.flock.aigentic.core.dsl.AgentConfig
import community.flock.aigentic.core.dsl.agent
import community.flock.aigentic.core.message.MimeType
import community.flock.aigentic.core.tool.Parameter
import community.flock.aigentic.core.tool.ParameterType.Primitive
import community.flock.aigentic.core.tool.Tool
import community.flock.aigentic.core.tool.ToolName
import community.flock.aigentic.core.tool.getItems
import kotlinx.serialization.Serializable
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonObject

/**
 * Example tool the agent calls to hand over the components it extracted from an invoice.
 *
 * It declares one required array parameter, "components", whose items are objects with a required
 * string "name" and a required string "value". The handler decodes those items into
 * [InvoiceComponent] instances and returns a JSON-encoded confirmation message containing their
 * count; nothing is persisted by this block.
 */
val saveInvoiceComponents =
    object : Tool {
        // Shape of a single entry of the "components" array: a name/value pair of strings.
        private val componentDefinition =
            Parameter.Complex.Object(
                name = "component",
                description = null,
                isRequired = true,
                parameters =
                    listOf(
                        Parameter.Primitive(
                            name = "name",
                            description = "The name of the invoice component e.g. number or date",
                            isRequired = true,
                            type = Primitive.String,
                        ),
                        Parameter.Primitive(
                            name = "value",
                            description = "The value of the component",
                            isRequired = true,
                            type = Primitive.String,
                        ),
                    ),
            )

        // The only parameter of the tool: the list of extracted components.
        val invoiceComponents =
            Parameter.Complex.Array(
                name = "components",
                description = "A list of the invoice components like number, date, customer_number, etc",
                isRequired = true,
                itemDefinition = componentDefinition,
            )

        override val name = ToolName("saveInvoiceComponents")
        override val description = "Saves the individual invoice components"
        override val parameters = listOf(invoiceComponents)

        // Decodes the received items and answers with how many were received.
        override val handler: suspend (JsonObject) -> String = { toolArguments ->
            val components = invoiceComponents.getItems<InvoiceComponent>(toolArguments)
            Json.encodeToString("Saved ${components.size} invoice components successfully")
        }
    }

/**
 * A single name/value pair extracted from an invoice.
 *
 * Its fields mirror the "name" and "value" properties declared for the items of the
 * "components" parameter of [saveInvoiceComponents].
 */
@Serializable
data class InvoiceComponent(val name: String, val value: String)

/**
 * Builds and starts an agent that extracts the components of an invoice PDF.
 *
 * The agent is given the PDF as Base64 context with [MimeType.PDF] and the
 * [saveInvoiceComponents] tool.
 *
 * @param invoicePdfBase64 the invoice PDF, Base64 encoded.
 * @param configureModel applied to the agent configuration before the task, tool and context are added.
 * @return the [Run] produced by starting the agent.
 */
suspend fun invoiceExtractorAgent(
    invoicePdfBase64: String,
    configureModel: AgentConfig.() -> Unit,
): Run =
    agent {
        configureModel()
        task("Extract the different invoice components") {
            addInstruction("Please provide list of the invoice components")
        }
        addTool(saveInvoiceComponents)
        context {
            addBase64(invoicePdfBase64, MimeType.PDF)
        }
    }.start()

This file was deleted.

Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ private val geminiKey by lazy {
}

// Set the active example and provider here
val activeRunExample = RunExamples.KOTLIN_MESSAGE_AGENT
val activeRunExample = RunExamples.INVOICE_EXTRACTOR_AGENT
val activeProvider = Provider.GEMINI

fun main() {
Expand All @@ -34,7 +34,15 @@ fun main() {
RunExamples.KOTLIN_MESSAGE_AGENT -> runKotlinMessageAgentExample(AgentConfig::configureModel)
RunExamples.ITEM_CATEGORIZE_AGENT ->
runItemCategorizeExample(
FileReader.readFile("/base64Image.txt"),
FileReader.readFileBase64("/table-items.png"),
AgentConfig::configureModel,
)
/**
* PDF is currently only supported by Gemini
*/
RunExamples.INVOICE_EXTRACTOR_AGENT ->
invoiceExtractorAgent(
FileReader.readFileBase64("/test-invoice.pdf"),
AgentConfig::configureModel,
)
}.also {
Expand Down Expand Up @@ -76,6 +84,7 @@ enum class RunExamples {
ADMINISTRATIVE_AGENT,
KOTLIN_MESSAGE_AGENT,
ITEM_CATEGORIZE_AGENT,
INVOICE_EXTRACTOR_AGENT,
}

enum class Provider {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
package community.flock.aigentic.example

import java.util.Base64

object FileReader {
fun readFile(path: String): String {
return this::class.java.getResource(path)!!.readText(Charsets.UTF_8).trim()
/**
 * Reads the classpath resource at [path] and returns its bytes encoded as a Base64 string.
 *
 * @param path resource path, resolved through this object's class.
 * @return the Base64 encoding of the resource's raw bytes.
 * @throws IllegalArgumentException when no resource is found at [path].
 */
fun readFileBase64(path: String): String {
    val resource = requireNotNull(this::class.java.getResource(path)) { "Resource not found: $path" }
    // `use` closes the stream whether or not reading succeeds.
    return resource.openStream().use { stream ->
        Base64.getEncoder().encodeToString(stream.readBytes())
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@ import community.flock.aigentic.core.message.Message
import community.flock.aigentic.core.message.MimeType
import community.flock.aigentic.core.message.Sender
import community.flock.aigentic.core.message.ToolCall
import community.flock.aigentic.gateway.wirespec.Base64MessageDto
import community.flock.aigentic.gateway.wirespec.ConfigDto
import community.flock.aigentic.gateway.wirespec.FatalResultDto
import community.flock.aigentic.gateway.wirespec.FinishedResultDto
import community.flock.aigentic.gateway.wirespec.ImageBase64MessageDto
import community.flock.aigentic.gateway.wirespec.ImageUrlMessageDto
import community.flock.aigentic.gateway.wirespec.MessageDto
import community.flock.aigentic.gateway.wirespec.MimeTypeDto
import community.flock.aigentic.gateway.wirespec.ModelRequestInfoDto
Expand All @@ -26,6 +25,7 @@ import community.flock.aigentic.gateway.wirespec.ToolCallDto
import community.flock.aigentic.gateway.wirespec.ToolCallsMessageDto
import community.flock.aigentic.gateway.wirespec.ToolDto
import community.flock.aigentic.gateway.wirespec.ToolResultMessageDto
import community.flock.aigentic.gateway.wirespec.UrlMessageDto
import community.flock.aigentic.providers.jsonschema.emitPropertiesAndRequired
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json
Expand Down Expand Up @@ -81,16 +81,16 @@ private fun Sender.toDto(): SenderDto =

private fun Message.toDto(): MessageDto =
when (this) {
is Message.ImageBase64 ->
ImageBase64MessageDto(
is Message.Base64 ->
Base64MessageDto(
createdAt = createdAt.toString(),
sender = sender.toDto(),
base64Content = base64Content,
mimeType = mimeType.toDto(),
)

is Message.ImageUrl ->
ImageUrlMessageDto(
is Message.Url ->
UrlMessageDto(
createdAt = createdAt.toString(),
sender = sender.toDto(),
url = url,
Expand Down Expand Up @@ -142,6 +142,7 @@ private fun MimeType.toDto(): MimeTypeDto =
MimeType.WEBP -> MimeTypeDto.IMAGE_WEBP
MimeType.HEIC -> MimeTypeDto.IMAGE_HEIC
MimeType.HEIF -> MimeTypeDto.IMAGE_HEIF
MimeType.PDF -> MimeTypeDto.APPLICATION_PDF
}

private fun Result.toDto() =
Expand Down
9 changes: 5 additions & 4 deletions src/platform/wirespec/gateway.ws
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ enum MimeTypeDto {
IMAGE_PNG,
IMAGE_WEBP,
IMAGE_HEIC,
IMAGE_HEIF
IMAGE_HEIF,
APPLICATION_PDF
}

type SystemPromptMessageDto {
Expand All @@ -72,14 +73,14 @@ type TextMessageDto {
text: String
}

type ImageUrlMessageDto {
type UrlMessageDto {
createdAt: String,
sender: SenderDto,
url: String,
mimeType: MimeTypeDto
}

type ImageBase64MessageDto {
type Base64MessageDto {
createdAt: String,
sender: SenderDto,
base64Content: String,
Expand All @@ -106,7 +107,7 @@ type ToolResultMessageDto {
response: String
}

type MessageDto = SystemPromptMessageDto | TextMessageDto | ImageUrlMessageDto | ImageBase64MessageDto | ToolCallsMessageDto | ToolResultMessageDto
type MessageDto = SystemPromptMessageDto | TextMessageDto | UrlMessageDto | Base64MessageDto | ToolCallsMessageDto | ToolResultMessageDto

type GatewayClientErrorDto {
message: String
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ internal fun createGenerateContentRequest(
contents =
messages.map { message ->
when (message) {
is Message.ImageUrl ->
is Message.Url ->
listOf(
Part.FileDataPart(FileDataContent(mimeType = message.mimeType.value, fileUri = message.url)),
)
is Message.ImageBase64 ->
is Message.Base64 ->
listOf(
Part.Blob(BlobContent(mimeType = message.mimeType.value, data = formatBase64Content(message))),
)
Expand Down Expand Up @@ -86,7 +86,7 @@ internal fun createGenerateContentRequest(
),
)

private fun formatBase64Content(message: Message.ImageBase64) = message.base64Content.substringAfter("base64,")
/**
 * Returns the Base64 payload of [message] without any leading data-URL style prefix.
 *
 * Everything up to and including the first "base64," marker is dropped; content without
 * that marker is returned unchanged.
 */
private fun formatBase64Content(message: Message.Base64): String {
    val rawContent = message.base64Content
    return rawContent.substringAfter("base64,")
}

private fun getSystemInstruction(messages: List<Message>): Content =
messages.filterIsInstance<Message.SystemPrompt>().map {
Expand Down
Loading

0 comments on commit fdddcbd

Please sign in to comment.