Skip to content

Commit

Permalink
PDF support for gemini
Browse files Browse the repository at this point in the history
  • Loading branch information
nsmnds committed Aug 9, 2024
1 parent 433ccb0 commit b306e0a
Show file tree
Hide file tree
Showing 21 changed files with 132 additions and 46 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -199,4 +199,4 @@ Aigentic is released under the MIT License.

## Support

For questions, issues, or feature requests, please open an issue on our GitHub repository. Or contact us as [info@flock.community](mailto:info@flock.community?subject=Aigentic)
For questions, issues, or feature requests, please open an issue on our GitHub repository. Or contact us at [info@aigentic.io](mailto:info@aigentic.io?subject=Aigentic)
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ data class Instruction(val text: String)
sealed interface Context {
data class Text(val text: String) : Context

data class ImageUrl(val url: String, val mimeType: MimeType) : Context
data class Url(val url: String, val mimeType: MimeType) : Context

data class ImageBase64(val base64: String, val mimeType: MimeType) : Context
data class Base64(val base64: String, val mimeType: MimeType) : Context
}

data class Agent(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,8 @@ private fun initializeStartMessages(agent: Agent): List<Message> =
listOf(agent.systemPromptBuilder.buildSystemPrompt(agent)) +
agent.contexts.map {
when (it) {
is Context.ImageUrl -> Message.ImageUrl(sender = Sender.Agent, url = it.url, mimeType = it.mimeType)
is Context.ImageBase64 -> Message.ImageBase64(sender = Sender.Agent, base64Content = it.base64, mimeType = it.mimeType)
is Context.Url -> Message.Url(sender = Sender.Agent, url = it.url, mimeType = it.mimeType)
is Context.Base64 -> Message.Base64(sender = Sender.Agent, base64Content = it.base64, mimeType = it.mimeType)
is Context.Text -> Message.Text(Sender.Agent, it.text)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ sealed interface AgentStatus {
fun Message.toStatus(): List<AgentStatus> =
when (this) {
is Message.SystemPrompt -> emptyList()
is Message.Text, is Message.ImageUrl, is Message.ImageBase64 -> emptyList()
is Message.Text, is Message.Url, is Message.Base64 -> emptyList()
is Message.ToolCalls ->
this.toolCalls.map {
when (it.name) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,16 +88,16 @@ class ContextConfig : Config<List<Context>> {
Context.Text(text)
.also { contexts.add(it) }

fun ContextConfig.addImageUrl(
fun ContextConfig.addUrl(
url: String,
mimeType: MimeType,
) = Context.ImageUrl(url = url, mimeType = mimeType)
) = Context.Url(url = url, mimeType = mimeType)
.also { contexts.add(it) }

fun ContextConfig.addImageBase64(
fun ContextConfig.addBase64(
base64: String,
mimeType: MimeType,
) = Context.ImageBase64(base64 = base64, mimeType = mimeType)
) = Context.Base64(base64 = base64, mimeType = mimeType)
.also { contexts.add(it) }

override fun build(): List<Context> = contexts
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ sealed class Message(
val text: String,
) : Message(sender)

data class ImageUrl(
data class Url(
override val sender: Sender,
val url: String,
val mimeType: MimeType,
) : Message(Sender.Agent)

data class ImageBase64(
data class Base64(
override val sender: Sender,
val base64Content: String,
val mimeType: MimeType,
Expand Down Expand Up @@ -54,6 +54,7 @@ enum class MimeType(val value: String) {
WEBP("image/webp"),
HEIC("image/heic"),
HEIF("image/heif"),
PDF("application/pdf"),
}

@JvmInline
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ class AgentExecutorTest : DescribeSpec({
task("Execute some task") {}
context {
addText(expectedTextContext)
addImageBase64(base64 = expectedImageContextBase64, mimeType = expectedImageContextMimeType)
addBase64(base64 = expectedImageContextBase64, mimeType = expectedImageContextMimeType)
}
addTool(mockk(relaxed = true))
}
Expand All @@ -159,7 +159,7 @@ class AgentExecutorTest : DescribeSpec({
messages.drop(1).take(2) shouldBe
listOf(
Message.Text(Sender.Agent, expectedTextContext),
Message.ImageBase64(
Message.Base64(
sender = Sender.Agent,
base64Content = expectedImageContextBase64,
mimeType = expectedImageContextMimeType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,13 @@ class AgentConfigTest : DescribeSpec({
task("Task description") {}
context {
addText("Some text")
addImageUrl("https://example.com/image.jpg", MimeType.JPEG)
addUrl("https://example.com/image.jpg", MimeType.JPEG)
}
addTool(mockk(relaxed = true))
}.run {
contexts.size shouldBe 2
contexts.first() shouldBe Context.Text("Some text")
contexts.last() shouldBe Context.ImageUrl("https://example.com/image.jpg", MimeType.JPEG)
contexts.last() shouldBe Context.Url("https://example.com/image.jpg", MimeType.JPEG)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ suspend fun runItemCategorizeExample(
task("Identify all items in the image and save each individual item") {}
addTool(saveItemTool)
context {
addImageBase64(base64Image, MimeType.JPEG)
addBase64(base64Image, MimeType.JPEG)
}
}.start()

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package community.flock.aigentic.example

import community.flock.aigentic.core.agent.Run
import community.flock.aigentic.core.agent.start
import community.flock.aigentic.core.dsl.AgentConfig
import community.flock.aigentic.core.dsl.agent
import community.flock.aigentic.core.message.MimeType
import community.flock.aigentic.core.tool.Parameter
import community.flock.aigentic.core.tool.ParameterType.Primitive
import community.flock.aigentic.core.tool.Tool
import community.flock.aigentic.core.tool.ToolName
import community.flock.aigentic.core.tool.getStringValue
import kotlinx.serialization.json.JsonObject

/**
 * Tool named `savePdfSummary` that the PDF summary example agent can call.
 *
 * It declares two required arguments: a `title` string and a `mainPoints`
 * array of strings. The handler only reads the title from the supplied
 * arguments and returns a confirmation message; it does not store anything.
 */
val savePdfSummary =
    object : Tool {
        // Required string argument holding the summary title.
        val titleParameter =
            Parameter.Primitive(
                name = "title",
                description = null,
                isRequired = true,
                type = Primitive.String,
            )

        // Definition of a single entry of the `mainPoints` array.
        val mainPointItem =
            Parameter.Primitive(
                name = "mainPoint",
                description = null,
                isRequired = true,
                type = Primitive.String,
            )

        // Required array argument listing the main points.
        val mainPointsParameter =
            Parameter.Complex.Array(
                name = "mainPoints",
                description = "List of main points mentioned in the article",
                isRequired = true,
                itemDefinition = mainPointItem,
            )

        override val name = ToolName("savePdfSummary")
        override val description = "Saves the summary of the PDF"
        override val parameters = listOf(titleParameter, mainPointsParameter)

        // Echoes the received title back in the tool result; the main points are not read here.
        override val handler: suspend (JsonObject) -> String = { arguments ->
            val savedTitle = titleParameter.getStringValue(arguments)
            "Successfully saved: '$savedTitle' "
        }
    }

/**
 * Builds an agent tasked with summarizing a PDF and starts it.
 *
 * @param pdfBase64 the PDF content as a Base64 string; added to the agent
 *   context with [MimeType.PDF].
 * @param configureModel applied to the agent configuration before the task,
 *   tool and context are registered.
 * @return the [Run] produced by starting the agent.
 */
suspend fun pdfSummaryAgent(
    pdfBase64: String,
    configureModel: AgentConfig.() -> Unit,
): Run =
    agent {
        configureModel()
        task("Summarize the content of a PDF") {
            addInstruction("Give the summary a comprehensive title")
            addInstruction("Please provide list of main points")
        }
        addTool(savePdfSummary)
        context {
            addBase64(pdfBase64, MimeType.PDF)
        }
    }.start()
Binary file not shown.

This file was deleted.

Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,15 @@ fun main() {
RunExamples.KOTLIN_MESSAGE_AGENT -> runKotlinMessageAgentExample(AgentConfig::configureModel)
RunExamples.ITEM_CATEGORIZE_AGENT ->
runItemCategorizeExample(
FileReader.readFile("/base64Image.txt"),
FileReader.readFileBase64("/table-items.png"),
AgentConfig::configureModel,
)
/**
* PDF is currently only supported by Gemini
*/
RunExamples.PDF_SUMMARY_AGENT ->
pdfSummaryAgent(
FileReader.readFileBase64("/aigentic.pdf"),
AgentConfig::configureModel,
)
}.also {
Expand Down Expand Up @@ -76,6 +84,7 @@ enum class RunExamples {
ADMINISTRATIVE_AGENT,
KOTLIN_MESSAGE_AGENT,
ITEM_CATEGORIZE_AGENT,
PDF_SUMMARY_AGENT,
}

enum class Provider {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
package community.flock.aigentic.example

import java.util.Base64

object FileReader {
fun readFile(path: String): String {
return this::class.java.getResource(path)!!.readText(Charsets.UTF_8).trim()

/**
 * Reads the classpath resource at [path] and returns its bytes encoded as Base64.
 *
 * @param path resource path passed to `Class.getResource`.
 * @return the Base64 encoding of the resource's full content.
 * @throws IllegalArgumentException when no resource is found for [path].
 */
fun readFileBase64(path: String): String {
    val resource = requireNotNull(this::class.java.getResource(path)) { "Resource not found: $path" }
    // `use` closes the stream once all bytes have been read.
    val content = resource.openStream().use { stream -> stream.readBytes() }
    return Base64.getEncoder().encodeToString(content)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@ import community.flock.aigentic.core.message.Message
import community.flock.aigentic.core.message.MimeType
import community.flock.aigentic.core.message.Sender
import community.flock.aigentic.core.message.ToolCall
import community.flock.aigentic.gateway.wirespec.Base64MessageDto
import community.flock.aigentic.gateway.wirespec.ConfigDto
import community.flock.aigentic.gateway.wirespec.FatalResultDto
import community.flock.aigentic.gateway.wirespec.FinishedResultDto
import community.flock.aigentic.gateway.wirespec.ImageBase64MessageDto
import community.flock.aigentic.gateway.wirespec.ImageUrlMessageDto
import community.flock.aigentic.gateway.wirespec.MessageDto
import community.flock.aigentic.gateway.wirespec.MimeTypeDto
import community.flock.aigentic.gateway.wirespec.ModelRequestInfoDto
Expand All @@ -26,6 +25,7 @@ import community.flock.aigentic.gateway.wirespec.ToolCallDto
import community.flock.aigentic.gateway.wirespec.ToolCallsMessageDto
import community.flock.aigentic.gateway.wirespec.ToolDto
import community.flock.aigentic.gateway.wirespec.ToolResultMessageDto
import community.flock.aigentic.gateway.wirespec.UrlMessageDto
import community.flock.aigentic.providers.jsonschema.emitPropertiesAndRequired
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json
Expand Down Expand Up @@ -81,16 +81,16 @@ private fun Sender.toDto(): SenderDto =

private fun Message.toDto(): MessageDto =
when (this) {
is Message.ImageBase64 ->
ImageBase64MessageDto(
is Message.Base64 ->
Base64MessageDto(
createdAt = createdAt.toString(),
sender = sender.toDto(),
base64Content = base64Content,
mimeType = mimeType.toDto(),
)

is Message.ImageUrl ->
ImageUrlMessageDto(
is Message.Url ->
UrlMessageDto(
createdAt = createdAt.toString(),
sender = sender.toDto(),
url = url,
Expand Down Expand Up @@ -142,6 +142,7 @@ private fun MimeType.toDto(): MimeTypeDto =
MimeType.WEBP -> MimeTypeDto.IMAGE_WEBP
MimeType.HEIC -> MimeTypeDto.IMAGE_HEIC
MimeType.HEIF -> MimeTypeDto.IMAGE_HEIF
MimeType.PDF -> MimeTypeDto.APPLICATION_PDF
}

private fun Result.toDto() =
Expand Down
9 changes: 5 additions & 4 deletions src/platform/wirespec/gateway.ws
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ enum MimeTypeDto {
IMAGE_PNG,
IMAGE_WEBP,
IMAGE_HEIC,
IMAGE_HEIF
IMAGE_HEIF,
APPLICATION_PDF
}

type SystemPromptMessageDto {
Expand All @@ -72,14 +73,14 @@ type TextMessageDto {
text: String
}

type ImageUrlMessageDto {
type UrlMessageDto {
createdAt: String,
sender: SenderDto,
url: String,
mimeType: MimeTypeDto
}

type ImageBase64MessageDto {
type Base64MessageDto {
createdAt: String,
sender: SenderDto,
base64Content: String,
Expand All @@ -106,7 +107,7 @@ type ToolResultMessageDto {
response: String
}

type MessageDto = SystemPromptMessageDto | TextMessageDto | ImageUrlMessageDto | ImageBase64MessageDto | ToolCallsMessageDto | ToolResultMessageDto
type MessageDto = SystemPromptMessageDto | TextMessageDto | UrlMessageDto | Base64MessageDto | ToolCallsMessageDto | ToolResultMessageDto

type GatewayClientErrorDto {
message: String
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ internal fun createGenerateContentRequest(
contents =
messages.map { message ->
when (message) {
is Message.ImageUrl ->
is Message.Url ->
listOf(
Part.FileDataPart(FileDataContent(mimeType = message.mimeType.value, fileUri = message.url)),
)
is Message.ImageBase64 ->
is Message.Base64 ->
listOf(
Part.Blob(BlobContent(mimeType = message.mimeType.value, data = formatBase64Content(message))),
)
Expand Down Expand Up @@ -86,7 +86,7 @@ internal fun createGenerateContentRequest(
),
)

private fun formatBase64Content(message: Message.ImageBase64) = message.base64Content.substringAfter("base64,")
/**
 * Returns the part of the message's Base64 content that follows the first
 * `base64,` marker, or the content unchanged when the marker is absent.
 */
private fun formatBase64Content(message: Message.Base64): String {
    val dataUrlMarker = "base64,"
    return message.base64Content.substringAfter(dataUrlMarker)
}

private fun getSystemInstruction(messages: List<Message>): Content =
messages.filterIsInstance<Message.SystemPrompt>().map {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ class GeminiRequestMapperKtTest : DescribeSpec({
it("Should not format when raw base64 content is provided") {
val base64Content = "iVBORw0KGgoAAA=="
val mimeType = MimeType.PNG
val imageBase64Message = Message.ImageBase64(Sender.Model, base64Content, mimeType)
val base64Message = Message.Base64(Sender.Model, base64Content, mimeType)

createGenerateContentRequest(listOf(imageBase64Message), emptyList()).contents[0].parts[0]
createGenerateContentRequest(listOf(base64Message), emptyList()).contents[0].parts[0]
.shouldBeInstanceOf<Part.Blob>().run {
this.inlineData shouldBe BlobContent(mimeType = mimeType.value, data = base64Content)
}
Expand All @@ -27,9 +27,9 @@ class GeminiRequestMapperKtTest : DescribeSpec({
it("should format when base64 data url is provided") {
val base64Content = ""
val mimeType = MimeType.PNG
val imageBase64Message = Message.ImageBase64(Sender.Model, base64Content, mimeType)
val base64Message = Message.Base64(Sender.Model, base64Content, mimeType)

createGenerateContentRequest(listOf(imageBase64Message), emptyList()).contents[0].parts[0]
createGenerateContentRequest(listOf(base64Message), emptyList()).contents[0].parts[0]
.shouldBeInstanceOf<Part.Blob>().run {
this.inlineData shouldBe BlobContent(mimeType = mimeType.value, data = "iVBORw0KGgoAAA==")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ object OpenAIMapper {
return when (this) {
is Message.SystemPrompt -> ChatMessage(role, prompt)
is Message.Text -> ChatMessage(role, text)
is Message.ImageUrl -> ChatMessage(role = role, listOf(ImagePart(url)))
is Message.ImageBase64 -> ChatMessage(role = role, listOf(ImagePart(formatDataUrl())))
is Message.Url -> ChatMessage(role = role, listOf(ImagePart(url)))
is Message.Base64 -> ChatMessage(role = role, listOf(ImagePart(formatDataUrl())))
is Message.ToolCalls ->
ChatMessage(
role = role,
Expand All @@ -89,7 +89,7 @@ object OpenAIMapper {
}
}

private fun Message.ImageBase64.formatDataUrl(): String =
/**
 * Returns the Base64 content as a data URL.
 *
 * Content that already starts with `data:` is returned as-is; otherwise it is
 * prefixed with `data:<mime type>;base64,`.
 */
private fun Message.Base64.formatDataUrl(): String =
    if (base64Content.startsWith("data:")) {
        base64Content
    } else {
        "data:${mimeType.value};base64,$base64Content"
    }

Expand All @@ -98,7 +98,7 @@ object OpenAIMapper {
is Message.SystemPrompt -> ChatRole.System
is Message.ToolCalls -> ChatRole.Assistant
is Message.ToolResult -> ChatRole.Tool
is Message.ImageUrl, is Message.ImageBase64, is Message.Text -> mapRole()
is Message.Url, is Message.Base64, is Message.Text -> mapRole()
}

private fun Message.mapRole() =
Expand Down
Loading

0 comments on commit b306e0a

Please sign in to comment.