Skip to content

Commit

Permalink
feat(BE-190): As a user, I want to be able to use the ollama-client
Browse files Browse the repository at this point in the history
 - add OllamaApi
 - add unit test
 - code clean & refactor
  • Loading branch information
hanrw committed Apr 16, 2024
1 parent 8ff4560 commit 619da0c
Show file tree
Hide file tree
Showing 24 changed files with 435 additions and 162 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ fun anthropicModules(
single<HttpRequester>(named("anthropicHttpRequester")) {
HttpRequester.default(
createHttpClient(
url = config.baseUrl,
host = config.baseUrl,
json = get(),
)
)
Expand Down
2 changes: 2 additions & 0 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ sonatypePortalPublisher {
dependencies {
kover(projects.openaiClient.openaiClientCore)
kover(projects.anthropicClient.anthropicClientCore)
kover(projects.openaiGateway.openaiGatewayCore)
kover(projects.ollamaClient.ollamaClientCore)
}

val autoVersion = project.property(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ actual fun HttpRequester.Companion.default(
): HttpRequester {
return DefaultHttpRequester(
createHttpClient(
url = { url },
host = { url },
authToken = { token },
json = JsonLenient
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ actual fun HttpRequester.Companion.default(
): HttpRequester {
return DefaultHttpRequester(
createHttpClient(
url = { url },
host = { url },
authToken = { token },
json = JsonLenient
)
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,12 +1,49 @@
package com.tddworks.ollama.api

import com.tddworks.ollama.api.chat.OllamaChat
import com.tddworks.ollama.api.internal.OllamaApi

/**
* @author hanrw
* @date 2024/4/14 17:32
* Interface for interacting with the Ollama API.
*/
class Ollama {
interface Ollama : OllamaChat {

companion object {
const val BASE_URL = "https://ollama.com"
const val ANTHROPIC_VERSION = "1.0.0"
const val BASE_URL = "localhost"
const val PORT = 11434
const val PROTOCOL = "http"
}

/**
* This function returns the base URL as a string.
*
* @return a string representing the base URL
*/
fun baseUrl(): String

/**
* This function returns the port as an integer.
*
* @return an integer representing the port
*/
fun port(): Int

/**
* This function returns the protocol as a string.
*
* @return a string representing the protocol
*/
fun protocol(): String
}

fun Ollama(
baseUrl: () -> String = { Ollama.BASE_URL },
port: () -> Int = { Ollama.PORT },
protocol: () -> String = { Ollama.PROTOCOL },
): Ollama {
return OllamaApi(
baseUrl = baseUrl(),
port = port(),
protocol = protocol()
)
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package com.tddworks.ollama.api
import org.koin.core.component.KoinComponent

data class OllamaConfig(
val apiKey: () -> String = { "CONFIG_API_KEY" },
val baseUrl: () -> String = { Ollama.BASE_URL },
val ollamaVersion: () -> String = { Ollama.ANTHROPIC_VERSION },
val protocol: () -> String = { Ollama.PROTOCOL },
val port: () -> Int = { Ollama.PORT },
) : KoinComponent
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package com.tddworks.ollama.api

import kotlinx.serialization.Serializable
import kotlin.jvm.JvmInline

/**
 * Type-safe identifier for an Ollama model, serialized as a plain string.
 *
 * Declared as an inline value class so model names get compile-time type
 * safety without a runtime wrapper allocation. Arbitrary model names can be
 * supplied through the primary constructor; the companion provides the
 * commonly used ones.
 */
@Serializable
@JvmInline
value class OllamaModel(val value: String) {
    companion object {
        /** Meta's Llama 2 model. */
        val LLAMA2 = OllamaModel("llama2")

        /** Code Llama, tuned for code generation. */
        val CODE_LLAMA = OllamaModel("codellama")

        /** Mistral model. */
        val MISTRAL = OllamaModel("mistral")
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package com.tddworks.ollama.api.chat

import kotlinx.coroutines.flow.Flow

interface OllamaChatApi {
interface OllamaChat {
suspend fun stream(request: OllamaChatRequest): Flow<OllamaChatResponse>
suspend fun request(request: OllamaChatRequest): OllamaChatResponse
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package com.tddworks.ollama.api.chat

import com.tddworks.common.network.api.StreamableRequest
import com.tddworks.common.network.api.StreamableRequest.Companion.STREAM
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.*


@Serializable
Expand All @@ -11,9 +13,17 @@ data class OllamaChatRequest(
@SerialName("messages") val messages: List<OllamaChatMessage>,
@SerialName("format") val format: String? = null,
// @SerialName("options") val options: Map<String, Any>? = null,
// @SerialName("stream") val stream: Boolean? = null,
@SerialName("keep_alive") val keepAlive: String? = null,
) : StreamableRequest
) : StreamableRequest {
fun asNonStreaming(jsonLenient: Json): JsonElement {
return jsonLenient.encodeToJsonElement(this)
.jsonObject.toMutableMap()
.apply {
put(STREAM, JsonPrimitive(false))
}
.let { JsonObject(it) }
}
}


@Serializable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,7 @@ import kotlinx.serialization.Serializable
* },
* "done": false
* }
*/
@Serializable
data class OllamaChatResponse(
@SerialName("model") val model: String,
@SerialName("created_at") val createdAt: String,
@SerialName("message") val message: OllamaChatMessage? = null,
@SerialName("done") val done: Boolean?,
@SerialName("total_duration") val totalDuration: Long? = null,
@SerialName("load_duration") val loadDuration: Long? = null,
@SerialName("prompt_eval_count") val promptEvalCount: Int? = null,
@SerialName("prompt_eval_duration") val promptEvalDuration: Long? = null,
@SerialName("eval_count") val evalCount: Int? = null,
@SerialName("eval_duration") val evalDuration: Long? = null,
)

/**
* ======== final response ========
* {
* "model": "llama2",
* "created_at": "2023-08-04T19:22:45.499127Z",
Expand All @@ -40,16 +25,35 @@ data class OllamaChatResponse(
* "eval_count": 468,
* "eval_duration": 7701267000
* }
*
* ======= Non-streaming response =======
* {
* "model": "llama2",
* "created_at": "2023-12-12T14:13:43.416799Z",
* "message": {
* "role": "assistant",
* "content": "Hello! How are you today?"
* },
* "done": true,
* "total_duration": 5191566416,
* "load_duration": 2154458,
* "prompt_eval_count": 26,
* "prompt_eval_duration": 383809000,
* "eval_count": 298,
* "eval_duration": 4799921000
* }
*/
@Serializable
data class FinalOllamaChatResponse(
data class OllamaChatResponse(
@SerialName("model") val model: String,
@SerialName("created_at") val createdAt: String,
@SerialName("done") val done: Boolean?,
@SerialName("total_duration") val totalDuration: Long?,
@SerialName("load_duration") val loadDuration: Long?,
@SerialName("prompt_eval_count") val promptEvalCount: Int?,
@SerialName("prompt_eval_duration") val promptEvalDuration: Long?,
@SerialName("eval_count") val evalCount: Int?,
@SerialName("eval_duration") val evalDuration: Long?,
@SerialName("message") val message: OllamaChatMessage? = null,
@SerialName("done") val done: Boolean,
// Below are the fields that are for final response or non-streaming response
@SerialName("total_duration") val totalDuration: Long? = null,
@SerialName("load_duration") val loadDuration: Long? = null,
@SerialName("prompt_eval_count") val promptEvalCount: Int? = null,
@SerialName("prompt_eval_duration") val promptEvalDuration: Long? = null,
@SerialName("eval_count") val evalCount: Int? = null,
@SerialName("eval_duration") val evalDuration: Long? = null,
)
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package com.tddworks.ollama.api.chat.internal
import com.tddworks.common.network.api.ktor.api.HttpRequester
import com.tddworks.common.network.api.ktor.api.performRequest
import com.tddworks.common.network.api.ktor.api.streamRequest
import com.tddworks.ollama.api.chat.OllamaChatApi
import com.tddworks.ollama.api.chat.OllamaChat
import com.tddworks.ollama.api.chat.OllamaChatRequest
import com.tddworks.ollama.api.chat.OllamaChatResponse
import io.ktor.client.request.*
Expand All @@ -14,7 +14,7 @@ import kotlinx.serialization.json.Json
class DefaultOllamaChatApi(
private val requester: HttpRequester,
private val jsonLenient: Json = JsonLenient,
) : OllamaChatApi {
) : OllamaChat {
override suspend fun stream(request: OllamaChatRequest): Flow<OllamaChatResponse> {
return requester.streamRequest<OllamaChatResponse> {
method = HttpMethod.Post
Expand All @@ -33,7 +33,7 @@ class DefaultOllamaChatApi(
return requester.performRequest {
method = HttpMethod.Post
url(path = CHAT_API_PATH)
setBody(request)
setBody(request.asNonStreaming(jsonLenient))
contentType(ContentType.Application.Json)
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package com.tddworks.ollama.api.internal

import com.tddworks.di.getInstance
import com.tddworks.ollama.api.Ollama
import com.tddworks.ollama.api.chat.OllamaChat

/**
 * Default [Ollama] implementation.
 *
 * Holds the connection settings (host, port, protocol) and delegates every
 * chat operation to the [OllamaChat] instance resolved from the DI container
 * via [getInstance].
 */
class OllamaApi(
    private val baseUrl: String,
    private val port: Int,
    private val protocol: String,
) : Ollama, OllamaChat by getInstance() {

    /** Host name this client is configured to talk to. */
    override fun baseUrl(): String = baseUrl

    /** TCP port this client is configured to talk to. */
    override fun port(): Int = port

    /** URL scheme (e.g. "http") this client is configured to use. */
    override fun protocol(): String = protocol
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
package com.tddworks.ollama.di

import com.tddworks.common.network.api.ktor.api.HttpRequester
import com.tddworks.common.network.api.ktor.internal.createHttpClient
import com.tddworks.common.network.api.ktor.internal.default
import com.tddworks.di.commonModule
import com.tddworks.ollama.api.Ollama
import com.tddworks.ollama.api.OllamaConfig
import com.tddworks.ollama.api.chat.OllamaChat
import com.tddworks.ollama.api.chat.internal.DefaultOllamaChatApi
import com.tddworks.ollama.api.chat.internal.JsonLenient
import kotlinx.serialization.json.Json
import org.koin.core.context.startKoin
import org.koin.core.qualifier.named
import org.koin.dsl.KoinAppDeclaration
import org.koin.dsl.module

Expand All @@ -16,8 +24,32 @@ fun iniOllamaKoin(config: OllamaConfig, appDeclaration: KoinAppDeclaration = {})
/**
 * Builds the Koin module that wires the Ollama client together: the public
 * [Ollama] facade, a lenient JSON codec, the HTTP requester bound to the
 * configured endpoint, and the chat API implementation.
 *
 * @param config connection settings (host, port, protocol) for the Ollama server
 */
fun ollamaModules(
    config: OllamaConfig,
) = module {

    // Public facade; exposes the configured connection settings and delegates
    // chat calls to the OllamaChat binding registered below.
    single<Ollama> {
        Ollama(
            baseUrl = config.baseUrl,
            port = config.port,
            protocol = config.protocol
        )
    }

    // Shared JSON instance for this module.
    // NOTE(review): presumably JsonLenient ignores unknown fields from the
    // server — confirm its configuration in chat.internal.JsonLenient.
    single<Json>(named("ollamaJson")) { JsonLenient }

    // HTTP requester pointed at the configured Ollama endpoint; qualified with
    // a name so it does not clash with requesters from other client modules.
    single<HttpRequester>(named("ollamaHttpRequester")) {
        HttpRequester.default(
            createHttpClient(
                protocol = config.protocol,
                port = config.port,
                host = config.baseUrl,
                json = get(named("ollamaJson")),
            )
        )
    }

    // Chat implementation consumed by the Ollama facade (via delegation).
    single<OllamaChat> {
        DefaultOllamaChatApi(
            jsonLenient = get(named("ollamaJson")),
            requester = get(named("ollamaHttpRequester"))
        )
    }
}
Loading

0 comments on commit 619da0c

Please sign in to comment.