Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Align proto primitives #187

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .changes/common/carpenter-beggar-creator-celery.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"type":"MAJOR","changes":["Better align protos in regards to primitive defaults."]}
1 change: 1 addition & 0 deletions .changes/generativeai/breath-brush-achiever-boat.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"type":"MAJOR","changes":["Better align protos in regards to primitive defaults."]}
71 changes: 71 additions & 0 deletions common/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Protos

> [!NOTE]
> Some code and documentation may refer to "Generative AI" as "labs". These two names are used
> interchangeably, and you should just register them as different names for the same service.

Protos are derived from a combination of the [Generative AI proto files](https://github.com/googleapis/googleapis/tree/master/google/ai/generativelanguage/v1beta)
and the [Vertex AI proto files](https://github.com/googleapis/googleapis/tree/master/google/cloud/aiplatform/v1beta1).

The goal is to maintain a sort of overlap between the two protos- representing their "common"
definitions.

## Organization

Within this SDK, the protos are defined under the following three categories.

### [Client](#client-protos)

You can find these types [here](https://github.com/google-gemini/generative-ai-android/blob/main/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt).

These are types that can only be sent _to_ the server; meaning the server will never respond
with them.

You can classify them as "client" only types, or "request" types.

### [Server](#server-protos)

You can find these types [here](https://github.com/google-gemini/generative-ai-android/blob/main/common/src/main/kotlin/com/google/ai/client/generativeai/common/server/Types.kt).

These are types that can only be sent _from_ the server; meaning the client will never create them
on their own.

You can classify them as "server" only types, or "response" types.

### [Shared](#shared-protos)

You can find these types [here](https://github.com/google-gemini/generative-ai-android/blob/main/common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt).

These are types that can both be sent _to_ and received _from_ the server; meaning the client can
create them, and the server can also respond with them.

You can classify them as "shared" types, or "common" types.

## Alignment efforts

In aligning with the proto, you should be mindful of the following practices:

### Field presence

Additional Context: [Presence in Proto3 APIs](https://github.com/google-gemini/generative-ai-android/blob/main/common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt)

- `optional` types should be nullable.
- non `optional` primitive types (including enums) should default to their [respective default](https://protobuf.dev/programming-guides/proto3/#default).
- `repeated` fields that are not marked with a `google.api.field_behavior` of `REQUIRED` should
default to an empty list or map.
- message fields that are marked with a `google.api.field_behavior` of `OPTIONAL` should be nullable.
- fields that are marked with a `google.api.field_behavior` of `REQUIRED` should *NOT* have a
default value, but *ONLY* when it's a [client](#client-protos) or [shared](#shared-protos) type.
- if a field is marked with both `optional` and a `google.api.field_behavior` of `REQUIRED`, then it
should be a nullable field that does _not_ default to null (ie; it needs to be explicitly set).

### Serial names

> [!NOTE]
> The exception to this rule is ENUM fields, which DO use `snake_case` serial names.

While the proto is defined in `snake_case`, it will respect and respond in `camelCase` if you send
the request in `camelCase`. As such, our protos do not have `@SerialName` annotations denoting their
`snake_case` alternative.

So all your fields should be defined in `camelCase` format.
Original file line number Diff line number Diff line change
Expand Up @@ -235,19 +235,19 @@ private suspend fun validateResponse(response: HttpResponse) {
if (message.contains("quota")) {
throw QuotaExceededException(message)
}
if (error.details?.any { "SERVICE_DISABLED" == it.reason } == true) {
if (error.details.any { "SERVICE_DISABLED" == it.reason }) {
throw ServiceDisabledException(message)
}
throw ServerException(message)
}

private fun GenerateContentResponse.validate() = apply {
if ((candidates?.isEmpty() != false) && promptFeedback == null) {
if (candidates.isEmpty() && promptFeedback == null) {
throw SerializationException("Error deserializing response, found no valid fields")
}
promptFeedback?.blockReason?.let { throw PromptBlockedException(this) }
candidates
?.mapNotNull { it.finishReason }
?.firstOrNull { it != FinishReason.STOP }
.map { it.finishReason }
.firstOrNull { it != FinishReason.STOP }
?.let { throw ResponseStoppedException(this) }
}
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class InvalidStateException(message: String, cause: Throwable? = null) :
*/
class ResponseStoppedException(val response: GenerateContentResponse, cause: Throwable? = null) :
GoogleGenerativeAIException(
"Content generation stopped. Reason: ${response.candidates?.first()?.finishReason?.name}",
"Content generation stopped. Reason: ${response.candidates.first().finishReason?.name}",
cause,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
* limitations under the License.
*/

@file:OptIn(ExperimentalSerializationApi::class)

package com.google.ai.client.generativeai.common

import com.google.ai.client.generativeai.common.client.GenerationConfig
Expand All @@ -22,45 +24,41 @@ import com.google.ai.client.generativeai.common.client.ToolConfig
import com.google.ai.client.generativeai.common.shared.Content
import com.google.ai.client.generativeai.common.shared.SafetySetting
import com.google.ai.client.generativeai.common.util.fullModelName
import kotlinx.serialization.SerialName
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.Serializable

sealed interface Request

@Serializable
data class GenerateContentRequest(
val model: String? = null,
val model: String,
val contents: List<Content>,
@SerialName("safety_settings") val safetySettings: List<SafetySetting>? = null,
@SerialName("generation_config") val generationConfig: GenerationConfig? = null,
val tools: List<Tool>? = null,
@SerialName("tool_config") var toolConfig: ToolConfig? = null,
@SerialName("system_instruction") val systemInstruction: Content? = null,
val safetySettings: List<SafetySetting> = emptyList(),
val generationConfig: GenerationConfig? = null,
val tools: List<Tool> = emptyList(),
val toolConfig: ToolConfig? = null,
val systemInstruction: Content? = null,
) : Request

@Serializable
data class CountTokensRequest(
val model: String,
val contents: List<Content> = emptyList(),
val tools: List<Tool> = emptyList(),
val generateContentRequest: GenerateContentRequest? = null,
val model: String? = null,
val contents: List<Content>? = null,
val tools: List<Tool>? = null,
@SerialName("system_instruction") val systemInstruction: Content? = null,
val systemInstruction: Content? = null,
) : Request {
companion object {
fun forGenAI(generateContentRequest: GenerateContentRequest) =
CountTokensRequest(
generateContentRequest =
generateContentRequest.model?.let {
generateContentRequest.copy(model = fullModelName(it))
} ?: generateContentRequest
)
fun forGenAI(request: GenerateContentRequest) =
CountTokensRequest(fullModelName(request.model), request.contents, emptyList(), request)

fun forVertexAI(generateContentRequest: GenerateContentRequest) =
fun forVertexAI(request: GenerateContentRequest) =
CountTokensRequest(
model = generateContentRequest.model?.let { fullModelName(it) },
contents = generateContentRequest.contents,
tools = generateContentRequest.tools,
systemInstruction = generateContentRequest.systemInstruction,
fullModelName(request.model),
request.contents,
request.tools,
null,
request.systemInstruction,
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,20 @@ sealed interface Response

@Serializable
data class GenerateContentResponse(
val candidates: List<Candidate>? = null,
val candidates: List<Candidate> = emptyList(),
val promptFeedback: PromptFeedback? = null,
val usageMetadata: UsageMetadata? = null,
) : Response

@Serializable
data class CountTokensResponse(val totalTokens: Int, val totalBillableCharacters: Int? = null) :
data class CountTokensResponse(val totalTokens: Int = 0, val totalBillableCharacters: Int = 0) :
Response

@Serializable data class GRpcErrorResponse(val error: GRpcError) : Response

@Serializable
data class UsageMetadata(
val promptTokenCount: Int? = null,
val candidatesTokenCount: Int? = null,
val totalTokenCount: Int? = null,
val promptTokenCount: Int = 0,
val candidatesTokenCount: Int = 0,
val totalTokenCount: Int = 0,
)
Original file line number Diff line number Diff line change
Expand Up @@ -22,32 +22,29 @@ import kotlinx.serialization.json.JsonObject

@Serializable
data class GenerationConfig(
val temperature: Float?,
@SerialName("top_p") val topP: Float?,
@SerialName("top_k") val topK: Int?,
@SerialName("candidate_count") val candidateCount: Int?,
@SerialName("max_output_tokens") val maxOutputTokens: Int?,
@SerialName("stop_sequences") val stopSequences: List<String>?,
@SerialName("response_mime_type") val responseMimeType: String? = null,
@SerialName("presence_penalty") val presencePenalty: Float? = null,
@SerialName("frequency_penalty") val frequencyPenalty: Float? = null,
@SerialName("response_schema") val responseSchema: Schema? = null,
val temperature: Float? = null,
val topP: Float? = null,
val topK: Int? = null,
val candidateCount: Int? = null,
val maxOutputTokens: Int? = null,
val stopSequences: List<String> = emptyList(),
val responseMimeType: String? = null,
val presencePenalty: Float? = null,
val frequencyPenalty: Float? = null,
val responseSchema: Schema? = null,
)

@Serializable
data class Tool(
val functionDeclarations: List<FunctionDeclaration>? = null,
val functionDeclarations: List<FunctionDeclaration> = emptyList(),
// This is a json object because it is not possible to make a data class with no parameters.
val codeExecution: JsonObject? = null,
)

@Serializable
data class ToolConfig(
@SerialName("function_calling_config") val functionCallingConfig: FunctionCallingConfig
)
@Serializable data class ToolConfig(val functionCallingConfig: FunctionCallingConfig? = null)

@Serializable
data class FunctionCallingConfig(val mode: Mode) {
data class FunctionCallingConfig(val mode: Mode = Mode.UNSPECIFIED) {
@Serializable
enum class Mode {
@SerialName("MODE_UNSPECIFIED") UNSPECIFIED,
Expand All @@ -58,16 +55,20 @@ data class FunctionCallingConfig(val mode: Mode) {
}

@Serializable
data class FunctionDeclaration(val name: String, val description: String, val parameters: Schema)
data class FunctionDeclaration(
val name: String,
val description: String,
val parameters: Schema? = null,
)

@Serializable
data class Schema(
val type: String,
val description: String? = null,
val format: String? = null,
val nullable: Boolean? = false,
val enum: List<String>? = null,
val properties: Map<String, Schema>? = null,
val required: List<String>? = null,
val description: String = "",
val format: String = "",
val nullable: Boolean = false,
val enum: List<String> = emptyList(),
val properties: Map<String, Schema> = emptyMap(),
val required: List<String> = emptyList(),
val items: Schema? = null,
)
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
* limitations under the License.
*/

@file:OptIn(ExperimentalSerializationApi::class)

package com.google.ai.client.generativeai.common.server

import com.google.ai.client.generativeai.common.shared.Content
Expand All @@ -36,8 +38,10 @@ object FinishReasonSerializer :

@Serializable
data class PromptFeedback(
// TODO() should default to UNSPECIFIED, but that would be an unexpected change for consumers null
// checking block reason to see if their prompt was blocked
val blockReason: BlockReason? = null,
val safetyRatings: List<SafetyRating>? = null,
val safetyRatings: List<SafetyRating> = emptyList(),
)

@Serializable(BlockReasonSerializer::class)
Expand All @@ -51,60 +55,54 @@ enum class BlockReason {
@Serializable
data class Candidate(
val content: Content? = null,
// TODO() should default to UNSPECIFIED, but that would be an unexpected change for consumers
// checking if their finish reason is anything other than STOP
val finishReason: FinishReason? = null,
val safetyRatings: List<SafetyRating>? = null,
val safetyRatings: List<SafetyRating> = emptyList(),
val citationMetadata: CitationMetadata? = null,
val groundingMetadata: GroundingMetadata? = null,
)

@Serializable
data class CitationMetadata
@OptIn(ExperimentalSerializationApi::class)
constructor(@JsonNames("citations") val citationSources: List<CitationSources>)
data class CitationMetadata(
@JsonNames("citations") val citationSources: List<CitationSources> = emptyList()
)

@Serializable
data class CitationSources(
val startIndex: Int = 0,
val endIndex: Int,
val uri: String,
val startIndex: Int? = null,
val endIndex: Int? = null,
val uri: String? = null,
val license: String? = null,
)

@Serializable
data class SafetyRating(
val category: HarmCategory,
val probability: HarmProbability,
val blocked: Boolean? = null, // TODO(): any reason not to default to false?
val probabilityScore: Float? = null,
val severity: HarmSeverity? = null,
val severityScore: Float? = null,
val category: HarmCategory = HarmCategory.UNSPECIFIED,
val probability: HarmProbability = HarmProbability.UNSPECIFIED,
val blocked: Boolean = false,
val probabilityScore: Float = 0f,
val severity: HarmSeverity = HarmSeverity.UNSPECIFIED,
val severityScore: Float = 0f,
)

@Serializable
data class GroundingMetadata(
@SerialName("web_search_queries") val webSearchQueries: List<String>?,
@SerialName("search_entry_point") val searchEntryPoint: SearchEntryPoint?,
@SerialName("retrieval_queries") val retrievalQueries: List<String>?,
@SerialName("grounding_attribution") val groundingAttribution: List<GroundingAttribution>?,
val webSearchQueries: List<String> = emptyList(),
val searchEntryPoint: SearchEntryPoint? = null,
val retrievalQueries: List<String> = emptyList(),
val groundingAttribution: List<GroundingAttribution> = emptyList(),
)

@Serializable
data class SearchEntryPoint(
@SerialName("rendered_content") val renderedContent: String?,
@SerialName("sdk_blob") val sdkBlob: String?,
)
data class SearchEntryPoint(val renderedContent: String = "", val sdkBlob: String = "")

// TODO() Has a different definition for labs vs vertex. May need to split into diff types in future
// (when labs supports it)
@Serializable
data class GroundingAttribution(
val segment: Segment,
@SerialName("confidence_score") val confidenceScore: Float?,
)
data class GroundingAttribution(val segment: Segment? = null, val confidenceScore: Float? = null)

@Serializable
data class Segment(
@SerialName("start_index") val startIndex: Int,
@SerialName("end_index") val endIndex: Int,
)
@Serializable data class Segment(val startIndex: Int = 0, val endIndex: Int = 0)

@Serializable(HarmProbabilitySerializer::class)
enum class HarmProbability {
Expand Down
Loading
Loading