diff --git a/ollama-client/ollama-client-core/src/jvmTest/kotlin/com/tddworks/ollama/api/chat/OllamaChatResponseTest.kt b/ollama-client/ollama-client-core/src/jvmTest/kotlin/com/tddworks/ollama/api/chat/OllamaChatResponseTest.kt index 5736f17..1302ecb 100644 --- a/ollama-client/ollama-client-core/src/jvmTest/kotlin/com/tddworks/ollama/api/chat/OllamaChatResponseTest.kt +++ b/ollama-client/ollama-client-core/src/jvmTest/kotlin/com/tddworks/ollama/api/chat/OllamaChatResponseTest.kt @@ -5,6 +5,7 @@ import org.junit.jupiter.api.Assertions.* import org.junit.jupiter.api.Test class OllamaChatResponseTest { + @Test fun `should decode response to non-streaming OllamaChatResponse`() { val response = """ diff --git a/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/ollama/api/chat/api/Extensions.kt b/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/ollama/api/chat/api/Extensions.kt index 8313f71..29e133c 100644 --- a/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/ollama/api/chat/api/Extensions.kt +++ b/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/ollama/api/chat/api/Extensions.kt @@ -16,25 +16,17 @@ import kotlinx.serialization.ExperimentalSerializationApi fun OllamaChatResponse.toOpenAIChatCompletion(): ChatCompletion { - return ChatCompletion( - id = createdAt, + return ChatCompletion(id = createdAt, created = 1L, model = model, - choices = listOf( - ChatChoice( - message = AssistantMessage( - content = message?.content ?: "", - role = when (message?.role) { - "user" -> Role.User - "assistant" -> Role.Assistant - "system" -> Role.System - else -> throw IllegalArgumentException("Unknown role: ${message?.role}") - } - ), - index = 0, + choices = message?.let { + listOf( + ChatChoice( + message = ChatMessage.assistant(it.content), + index = 0, + ) ) - ) - ) + } ?: emptyList()) } fun OllamaChatResponse.toOpenAIChatCompletionChunk(): ChatCompletionChunk { @@ -49,8 +41,7 @@ fun OllamaChatResponse.toOpenAIChatCompletionChunk(): ChatCompletionChunk { ) ) - return ChatCompletionChunk( - id = id, + return ChatCompletionChunk(id = id, `object` = "ollama-chunk", created = created, model = model, @@ -60,31 +51,27 @@ fun OllamaChatResponse.toOpenAIChatCompletionChunk(): ChatCompletionChunk { content = message?.content, ) ) - } - ) + }) } @OptIn(ExperimentalSerializationApi::class) fun ChatCompletionRequest.toOllamaChatRequest(): OllamaChatRequest { - return OllamaChatRequest( - model = model.value, - messages = messages.map { - OllamaChatMessage( - role = when (it.role) { - Role.User -> "user" - Role.Assistant -> "assistant" - Role.System -> "system" - else -> throw IllegalArgumentException("Unknown role: ${it.role}") - }, - content = when (it) { - is UserMessage -> it.content - is AssistantMessage -> it.content - is SystemMessage -> it.content - else -> throw IllegalArgumentException("Unknown message type: $it") - }, - ) - } - ) + return OllamaChatRequest(model = model.value, messages = messages.map { + OllamaChatMessage( + role = when (it.role) { + Role.User -> "user" + Role.Assistant -> "assistant" + Role.System -> "system" + else -> throw IllegalArgumentException("Unknown role: ${it.role}") + }, + content = when (it) { + is UserMessage -> it.content + is AssistantMessage -> it.content + is SystemMessage -> it.content + else -> throw IllegalArgumentException("Unknown message type: $it") + }, + ) + }) } /** @@ -98,12 +85,9 @@ fun CompletionRequest.toOllamaGenerateRequest(): OllamaGenerateRequest { maxTokens?.let { options["num_predict"] = it } stop?.let { options["stop"] = it.split(",").toTypedArray() } return OllamaGenerateRequest( - model = model.value, - prompt = prompt, - stream = stream ?: false, + model = model.value, prompt = prompt, stream = stream ?: false, // Looks only here can adapt the raw option - raw = (streamOptions?.get("raw") ?: false) as Boolean, - options = options + raw = (streamOptions?.get("raw") ?: false) as Boolean, options = options ) } @@ -111,21 +95,17 @@ fun CompletionRequest.toOllamaGenerateRequest(): OllamaGenerateRequest { * Convert OllamaGenerateResponse to OpenAI Completion */ fun OllamaGenerateResponse.toOpenAICompletion(): Completion { - return Completion( - id = createdAt, + return Completion(id = createdAt, model = model, created = 1, choices = listOf( CompletionChoice( - text = response, - index = 0, - finishReason = doneReason ?: "" + text = response, index = 0, finishReason = doneReason ?: "" ) ), usage = Usage( promptTokens = promptEvalCount, completionTokens = evalCount, - totalTokens = evalCount?.let { promptEvalCount?.plus(it) } - ) - ) + totalTokens = evalCount?.let { promptEvalCount?.plus(it) ?: it } ?: 0, + )) } \ No newline at end of file diff --git a/openai-gateway/openai-gateway-core/src/jvmTest/kotlin/com/tddworks/ollama/api/chat/api/ExtensionsTest.kt b/openai-gateway/openai-gateway-core/src/jvmTest/kotlin/com/tddworks/ollama/api/chat/api/ExtensionsTest.kt index a0bd8b3..a228806 100644 --- a/openai-gateway/openai-gateway-core/src/jvmTest/kotlin/com/tddworks/ollama/api/chat/api/ExtensionsTest.kt +++ b/openai-gateway/openai-gateway-core/src/jvmTest/kotlin/com/tddworks/ollama/api/chat/api/ExtensionsTest.kt @@ -1,5 +1,6 @@ package com.tddworks.ollama.api.chat.api +import com.tddworks.common.network.api.ktor.api.AnySerial import com.tddworks.ollama.api.OllamaModel import com.tddworks.ollama.api.chat.OllamaChatMessage import com.tddworks.ollama.api.chat.OllamaChatResponse @@ -7,7 +8,9 @@ import com.tddworks.ollama.api.generate.OllamaGenerateResponse import com.tddworks.openai.api.chat.api.ChatCompletionRequest import com.tddworks.openai.api.chat.api.ChatMessage import com.tddworks.openai.api.chat.api.Model +import com.tddworks.openai.api.chat.api.Role import com.tddworks.openai.api.legacy.completions.api.CompletionRequest +import com.tddworks.openai.api.legacy.completions.api.Usage import kotlinx.serialization.ExperimentalSerializationApi import org.junit.jupiter.api.Assertions.* import org.junit.jupiter.api.Test @@ -16,7 +19,24 @@ import org.junit.jupiter.api.Test class ExtensionsTest { @Test - fun `should convert CompletionRequest to OllamaGenerateRequest`() { + fun `should convert CompletionRequest to OllamaGenerateRequest with required fields`() { + val completionRequest = CompletionRequest( + model = Model(OllamaModel.CODE_LLAMA.value), + prompt = "Once upon a time", + ) + + val ollamaGenerateRequest = completionRequest.toOllamaGenerateRequest() + assertEquals("codellama", ollamaGenerateRequest.model) + assertEquals("Once upon a time", ollamaGenerateRequest.prompt) + assertFalse(ollamaGenerateRequest.stream) + + assertFalse(ollamaGenerateRequest.raw) + + assertEquals(emptyMap(), ollamaGenerateRequest.options) + } + + @Test + fun `should convert CompletionRequest to OllamaGenerateRequest with all fields`() { val completionRequest = CompletionRequest( model = Model(OllamaModel.CODE_LLAMA.value), prompt = "Once upon a time", @@ -48,7 +68,55 @@ class ExtensionsTest { } @Test - fun `should convert OllamaGenerateResponse to OpenAICompletion`() { + fun `should convert OllamaGenerateResponse to OpenAICompletion with required fields`() { + val ollamaGenerateResponse = OllamaGenerateResponse( + model = "some-model", + createdAt = "createdAt", + response = "response", + done = false, + ) + val openAICompletion = ollamaGenerateResponse.toOpenAICompletion() + assertEquals("createdAt", openAICompletion.id) + assertEquals(1, openAICompletion.created) + assertEquals("some-model", openAICompletion.model) + assertEquals(1, openAICompletion.choices.size) + assertEquals("response", openAICompletion.choices[0].text) + assertEquals(0, openAICompletion.choices[0].index) + assertEquals("", openAICompletion.choices[0].finishReason) + assertEquals(Usage(totalTokens = 0), openAICompletion.usage) + } + + @Test + fun `should convert OllamaGenerateResponse to OpenAICompletion without promptEvalCount`() { + val ollamaGenerateResponse = OllamaGenerateResponse( + model = "some-model", + createdAt = "createdAt", + response = "response", + done = false, + evalCount = 10, + evalDuration = 1000, + loadDuration = 1000, + promptEvalDuration = 1000, + ) + val openAICompletion = ollamaGenerateResponse.toOpenAICompletion() + assertEquals("createdAt", openAICompletion.id) + assertEquals(1, openAICompletion.created) + assertEquals("some-model", openAICompletion.model) + assertEquals(1, openAICompletion.choices.size) + assertEquals("response", openAICompletion.choices[0].text) + assertEquals(0, openAICompletion.choices[0].index) + assertEquals("", openAICompletion.choices[0].finishReason) + assertEquals( + Usage( + promptTokens = null, + completionTokens = 10, + totalTokens = 10 + ), openAICompletion.usage + ) + } + + @Test + fun `should convert OllamaGenerateResponse to OpenAICompletion with all fields`() { val ollamaGenerateResponse = OllamaGenerateResponse.dummy() val openAICompletion = ollamaGenerateResponse.toOpenAICompletion() assertEquals("createdAt", openAICompletion.id) @@ -64,7 +132,21 @@ class ExtensionsTest { } @Test - fun `should convert OllamaChatResponse to OpenAIChatCompletion`() { + fun `should convert OllamaChatResponse to OpenAIChatCompletion without message`() { + val ollamaChatResponse = OllamaChatResponse( + createdAt = "123", + model = "llama2", + done = false + ) + val openAIChatCompletion = ollamaChatResponse.toOpenAIChatCompletion() + assertEquals("123", openAIChatCompletion.id) + assertEquals(1L, openAIChatCompletion.created) + assertEquals("llama2", openAIChatCompletion.model) + assertEquals(0, openAIChatCompletion.choices.size) + } + + @Test + fun `should convert OllamaChatResponse to OpenAIChatCompletion with all fields`() { val ollamaChatResponse = OllamaChatResponse( createdAt = "123", model = "llama2", @@ -84,6 +166,37 @@ class ExtensionsTest { assertEquals("assistant", openAIChatCompletion.choices[0].message.role.name) } + @Test + fun `should throw IllegalArgumentException when message not recognized`() { + val chatCompletionRequest = ChatCompletionRequest( + model = Model(OllamaModel.LLAMA2.value), + messages = listOf( + ChatMessage.vision(emptyList()) + ) + ) + + assertThrows(IllegalArgumentException::class.java) { + chatCompletionRequest.toOllamaChatRequest() + } + } + + @Test + fun `should throw IllegalArgumentException when role not recognized`() { + val chatCompletionRequest = ChatCompletionRequest( + model = Model(OllamaModel.LLAMA2.value), + messages = listOf( + ChatMessage.UserMessage( + content = "Hello", + role = Role.Tool + ) + ) + ) + + assertThrows(IllegalArgumentException::class.java) { + chatCompletionRequest.toOllamaChatRequest() + } + } + @Test fun `should convert ChatCompletionRequest to OllamaChatRequest`() { val chatCompletionRequest = ChatCompletionRequest(