feat: support for Google Gemini (charmbracelet#314)
* feat(stream): Filter out System role messages for Anthropic API requests

* feat(api): add support for Google API and Gemini models

* refactor(google): simplify variable declarations in google.go

* feat: add TopK parameter for token generation control

* fix(formatter): remove redundant default FormatText assignment

This was breaking the usage of the `--format` and `--format-as` flags
for every value other than `markdown`.

* chore(core): initialize FormatText config with default value if nil

---------

Co-authored-by: guzmonne <[email protected]>
cloudbridgeuy and guzmonne authored Aug 28, 2024
1 parent ae04564 commit e9af3b6
Showing 9 changed files with 380 additions and 3 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -165,6 +165,7 @@ Check the [`./features.md`](./features.md) for more details.
- `--fanciness`: Level of fanciness.
- `--temp`: Sampling temperature.
- `--topp`: Top P value.
- `--topk`: Top K value.
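For example (a hypothetical invocation, not part of this diff), the new flag composes with the existing sampling flags: `mods --topk 40 --temp 0.7 'name three whales'`.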

## Custom Roles

3 changes: 2 additions & 1 deletion anthropic.go
@@ -61,6 +61,7 @@ type AnthropicMessageCompletionRequest struct {
MaxTokens int `json:"max_tokens"`
Temperature float32 `json:"temperature,omitempty"`
TopP float32 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stream bool `json:"stream,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
}
@@ -253,7 +254,7 @@ func (stream *anthropicStreamReader) processLines() (openai.ChatCompletionStream
var chunk AnthropicCompletionMessageResponse
unmarshalErr := stream.unmarshaler.Unmarshal(noPrefixLine, &chunk)
if unmarshalErr != nil {
-return *new(openai.ChatCompletionStreamResponse), fmt.Errorf("ollamaStreamReader.processLines: %w", unmarshalErr)
+return *new(openai.ChatCompletionStreamResponse), fmt.Errorf("anthropicStreamReader.processLines: %w", unmarshalErr)
}

if chunk.Type != "content_block_delta" {
2 changes: 2 additions & 0 deletions config.go
@@ -53,6 +53,7 @@ var help = map[string]string{
"temp": "Temperature (randomness) of results, from 0.0 to 2.0.",
"stop": "Up to 4 sequences where the API will stop generating further tokens.",
"topp": "TopP, an alternative to temperature that narrows response, from 0.0 to 1.0.",
"topk": "TopK, only sample from the top K options for each subsequent token.",
"fanciness": "Your desired level of fanciness.",
"status-text": "Text to show while generating.",
"settings": "Open settings in your $EDITOR.",
@@ -139,6 +140,7 @@ type Config struct {
Temperature float32 `yaml:"temp" env:"TEMP"`
Stop []string `yaml:"stop" env:"STOP"`
TopP float32 `yaml:"topp" env:"TOPP"`
TopK int `yaml:"topk" env:"TOPK"`
NoLimit bool `yaml:"no-limit" env:"NO_LIMIT"`
CachePath string `yaml:"cache-path" env:"CACHE_PATH"`
NoCache bool `yaml:"no-cache" env:"NO_CACHE"`
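The `topk` help string above is terse. As a toy illustration — not part of this commit, and independent of any real model — top-k sampling keeps only the K most probable candidates for the next token and renormalizes before sampling:

func topK(probs map[string]float64, k int) map[string]float64 {
	// Collect and sort the candidate tokens by descending probability.
	type cand struct {
		tok string
		p   float64
	}
	cands := make([]cand, 0, len(probs))
	for t, p := range probs {
		cands = append(cands, cand{t, p})
	}
	sort.Slice(cands, func(i, j int) bool { return cands[i].p > cands[j].p })
	if k > len(cands) {
		k = len(cands)
	}
	// Keep the top k and renormalize so the kept probabilities sum to 1.
	var total float64
	for _, c := range cands[:k] {
		total += c.p
	}
	kept := make(map[string]float64, k)
	for _, c := range cands[:k] {
		kept[c.tok] = c.p / total
	}
	return kept
}

// topK(map[string]float64{"the": 0.5, "a": 0.3, "an": 0.15, "xyzzy": 0.05}, 2)
// => map[a:0.375 the:0.625]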
10 changes: 10 additions & 0 deletions config_template.yml
@@ -25,6 +25,8 @@ quiet: false
temp: 1.0
# {{ index .Help "topp" }}
topp: 1.0
# {{ index .Help "topk" }}
topk: 50
# {{ index .Help "no-limit" }}
no-limit: false
# {{ index .Help "word-wrap" }}
@@ -106,6 +108,14 @@ apis:
max-input-chars: 128000
command-r:
max-input-chars: 128000
google:
models:
gemini-1.5-pro-latest:
aliases: ["gemini"]
max-input-chars: 392000
gemini-1.5-flash-latest:
aliases: ["flash"]
max-input-chars: 392000
ollama:
base-url: http://localhost:11434/api
models: # https://ollama.com/library
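With these entries in place, the Gemini models can be selected by alias — e.g. `mods --model flash 'explain this diff'` (hypothetical prompt; alias resolution works the same way as for the other APIs configured above).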
280 changes: 280 additions & 0 deletions google.go
@@ -0,0 +1,280 @@
package main

import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"

openai "github.com/sashabaranov/go-openai"
)

var googleHeaderData = []byte("data: ")

// GoogleClientConfig represents the configuration for the Google API client.
type GoogleClientConfig struct {
BaseURL string
HTTPClient *http.Client
EmptyMessagesLimit uint
}

// DefaultGoogleConfig returns the default configuration for the Google API client.
func DefaultGoogleConfig(model, authToken string) GoogleClientConfig {
return GoogleClientConfig{
BaseURL: fmt.Sprintf("https://generativelanguage.googleapis.com/v1beta/models/%s:streamGenerateContent?alt=sse&key=%s", model, authToken),
HTTPClient: &http.Client{},
EmptyMessagesLimit: defaultEmptyMessagesLimit,
}
}
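// For illustration (hypothetical key), DefaultGoogleConfig("gemini-1.5-flash-latest", "AIza…")
// yields a BaseURL of:
//
//	https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash-latest:streamGenerateContent?alt=sse&key=AIza…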

// GoogleParts is a datatype containing media that is part of a multi-part Content message.
type GoogleParts struct {
Text string `json:"text,omitempty"`
}

// GoogleContent is the base structured datatype containing multi-part content of a message.
type GoogleContent struct {
Parts []GoogleParts `json:"parts,omitempty"`
Role string `json:"role,omitempty"`
}

// GoogleGenerationConfig are the options for model generation and outputs. Not all parameters are configurable for every model.
type GoogleGenerationConfig struct {
StopSequences []string `json:"stopSequences,omitempty"`
ResponseMimeType string `json:"responseMimeType,omitempty"`
CandidateCount uint `json:"candidateCount,omitempty"`
MaxOutputTokens uint `json:"maxOutputTokens,omitempty"`
Temperature float32 `json:"temperature,omitempty"`
TopP float32 `json:"topP,omitempty"`
TopK int `json:"topK,omitempty"`
}

// GoogleMessageCompletionRequest represents the valid parameters and value options for the request.
type GoogleMessageCompletionRequest struct {
Contents []GoogleContent `json:"contents,omitempty"`
GenerationConfig GoogleGenerationConfig `json:"generationConfig,omitempty"`
}
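// For illustration, a marshaled request body looks like (hypothetical values):
//
//	{
//	  "contents": [{"parts": [{"text": "Hello"}], "role": "user"}],
//	  "generationConfig": {"temperature": 0.7, "topP": 1, "topK": 50}
//	}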

// GoogleRequestBuilder is an interface for building HTTP requests for the Google API.
type GoogleRequestBuilder interface {
Build(ctx context.Context, method, url string, body any, header http.Header) (*http.Request, error)
}

// NewGoogleRequestBuilder creates a new HTTPRequestBuilder.
func NewGoogleRequestBuilder() *HTTPRequestBuilder {
return &HTTPRequestBuilder{
marshaller: &JSONMarshaller{},
}
}

// GoogleClient is a client for the Google API.
type GoogleClient struct {
config GoogleClientConfig

requestBuilder GoogleRequestBuilder
}

// NewGoogleClientWithConfig creates a new GoogleClient with the given configuration.
func NewGoogleClientWithConfig(config GoogleClientConfig) *GoogleClient {
return &GoogleClient{
config: config,
requestBuilder: NewGoogleRequestBuilder(),
}
}

func (c *GoogleClient) newRequest(ctx context.Context, method, url string, setters ...requestOption) (*http.Request, error) {
// Default Options
args := &requestOptions{
body: nil,
header: make(http.Header),
}
for _, setter := range setters {
setter(args)
}
req, err := c.requestBuilder.Build(ctx, method, url, args.body, args.header)
if err != nil {
return new(http.Request), err
}
return req, nil
}

func (c *GoogleClient) handleErrorResp(resp *http.Response) error {
// Decode the error payload from the response body.
var errRes openai.ErrorResponse
err := json.NewDecoder(resp.Body).Decode(&errRes)
if err != nil || errRes.Error == nil {
reqErr := &openai.RequestError{
HTTPStatusCode: resp.StatusCode,
Err: err,
}
if errRes.Error != nil {
reqErr.Err = errRes.Error
}
return reqErr
}

errRes.Error.HTTPStatusCode = resp.StatusCode
return errRes.Error
}

// GoogleCandidate represents a response candidate generated from the model.
type GoogleCandidate struct {
Content GoogleContent `json:"content,omitempty"`
FinishReason string `json:"finishReason,omitempty"`
TokenCount uint `json:"tokenCount,omitempty"`
Index uint `json:"index,omitempty"`
}

// GoogleCompletionMessageResponse represents a response to a Google completion message.
type GoogleCompletionMessageResponse struct {
Candidates []GoogleCandidate `json:"candidates,omitempty"`
}

// GoogleChatCompletionStream represents a stream for chat completion.
type GoogleChatCompletionStream struct {
*googleStreamReader
}

type googleStreamReader struct {
emptyMessagesLimit uint
isFinished bool

reader *bufio.Reader
response *http.Response
errAccumulator ErrorAccumulator
unmarshaler Unmarshaler

httpHeader
}

// Recv reads the next response from the stream.
func (stream *googleStreamReader) Recv() (response openai.ChatCompletionStreamResponse, err error) {
if stream.isFinished {
err = io.EOF
return
}

response, err = stream.processLines()
return
}

// Close closes the stream.
func (stream *googleStreamReader) Close() error {
return stream.response.Body.Close() //nolint:wrapcheck
}

//nolint:gocognit
func (stream *googleStreamReader) processLines() (openai.ChatCompletionStreamResponse, error) {
var (
emptyMessagesCount uint
hasError bool
)

for {
rawLine, readErr := stream.reader.ReadBytes('\n')

if readErr != nil {
return *new(openai.ChatCompletionStreamResponse), fmt.Errorf("googleStreamReader.processLines: %w", readErr)
}

noSpaceLine := bytes.TrimSpace(rawLine)

if bytes.HasPrefix(noSpaceLine, errorPrefix) {
hasError = true
// NOTE: Continue to the next event to get the error data.
continue
}

if !bytes.HasPrefix(noSpaceLine, googleHeaderData) || hasError {
if hasError {
noSpaceLine = bytes.TrimPrefix(noSpaceLine, googleHeaderData)
}
writeErr := stream.errAccumulator.Write(noSpaceLine)
if writeErr != nil {
return *new(openai.ChatCompletionStreamResponse), fmt.Errorf("googleStreamReader.processLines: %w", writeErr)
}
emptyMessagesCount++
if emptyMessagesCount > stream.emptyMessagesLimit {
return *new(openai.ChatCompletionStreamResponse), ErrTooManyEmptyStreamMessages
}
continue
}

noPrefixLine := bytes.TrimPrefix(noSpaceLine, googleHeaderData)

var chunk GoogleCompletionMessageResponse
unmarshalErr := stream.unmarshaler.Unmarshal(noPrefixLine, &chunk)
if unmarshalErr != nil {
return *new(openai.ChatCompletionStreamResponse), fmt.Errorf("googleStreamReader.processLines: %w", unmarshalErr)
}

// NOTE: Leverage the existing logic based on OpenAI ChatCompletionStreamResponse by
// converting the Google events into them.
if len(chunk.Candidates) == 0 {
continue
}
parts := chunk.Candidates[0].Content.Parts
if len(parts) == 0 {
continue
}
response := openai.ChatCompletionStreamResponse{
Choices: []openai.ChatCompletionStreamChoice{
{
Index: 0,
Delta: openai.ChatCompletionStreamChoiceDelta{
Content: parts[0].Text,
Role: "assistant",
},
},
},
}

return response, nil
}
}
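// For illustration, a typical event on the wire looks like (hypothetical payload):
//
//	data: {"candidates":[{"content":{"parts":[{"text":"Hi"}],"role":"model"},"index":0}]}
//
// which the loop above turns into an openai.ChatCompletionStreamResponse whose
// single choice carries Delta{Role: "assistant", Content: "Hi"}.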

func googleSendRequestStream(client *GoogleClient, req *http.Request) (*googleStreamReader, error) {
req.Header.Set("content-type", "application/json")

resp, err := client.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close()
if err != nil {
return new(googleStreamReader), err
}
if isFailureStatusCode(resp) {
return new(googleStreamReader), client.handleErrorResp(resp)
}
return &googleStreamReader{
emptyMessagesLimit: client.config.EmptyMessagesLimit,
reader: bufio.NewReader(resp.Body),
response: resp,
errAccumulator: NewErrorAccumulator(),
unmarshaler: &JSONUnmarshaler{},
httpHeader: httpHeader(resp.Header),
}, nil
}

// CreateChatCompletionStream — API call to create a chat completion w/ streaming
// support. Tokens are sent as data-only server-sent events as they become
// available; the stream ends once the response body is exhausted.
func (c *GoogleClient) CreateChatCompletionStream(
ctx context.Context,
request GoogleMessageCompletionRequest,
) (stream *GoogleChatCompletionStream, err error) {
req, err := c.newRequest(ctx, http.MethodPost, c.config.BaseURL, withBody(request))
if err != nil {
return nil, err
}

resp, err := googleSendRequestStream(c, req)
if err != nil {
return
}
stream = &GoogleChatCompletionStream{
googleStreamReader: resp,
}
return
}
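Taken together, the new client composes like this — a minimal sketch, not part of the commit, assuming it lives in this same package (imports: context, errors, fmt, io, os); the model name, prompt, and sampling values are placeholders:

func demoGeminiStream() error {
	cfg := DefaultGoogleConfig("gemini-1.5-flash-latest", os.Getenv("GOOGLE_API_KEY"))
	client := NewGoogleClientWithConfig(cfg)

	stream, err := client.CreateChatCompletionStream(context.Background(), GoogleMessageCompletionRequest{
		Contents: []GoogleContent{
			{Role: "user", Parts: []GoogleParts{{Text: "Say hello in five words."}}},
		},
		GenerationConfig: GoogleGenerationConfig{Temperature: 0.7, TopK: 50},
	})
	if err != nil {
		return err
	}
	defer stream.Close()

	for {
		resp, err := stream.Recv()
		if errors.Is(err, io.EOF) { // Recv wraps the reader's EOF via %w
			return nil
		}
		if err != nil {
			return err
		}
		fmt.Print(resp.Choices[0].Delta.Content)
	}
}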
6 changes: 5 additions & 1 deletion main.go
@@ -245,6 +245,7 @@ func initFlags() {
flags.Float32Var(&config.Temperature, "temp", config.Temperature, stdoutStyles().FlagDesc.Render(help["temp"]))
flags.StringArrayVar(&config.Stop, "stop", config.Stop, stdoutStyles().FlagDesc.Render(help["stop"]))
flags.Float32Var(&config.TopP, "topp", config.TopP, stdoutStyles().FlagDesc.Render(help["topp"]))
flags.IntVar(&config.TopK, "topk", config.TopK, stdoutStyles().FlagDesc.Render(help["topk"]))
flags.UintVar(&config.Fanciness, "fanciness", config.Fanciness, stdoutStyles().FlagDesc.Render(help["fanciness"]))
flags.StringVar(&config.StatusText, "status-text", config.StatusText, stdoutStyles().FlagDesc.Render(help["status-text"]))
flags.BoolVar(&config.NoCache, "no-cache", config.NoCache, stdoutStyles().FlagDesc.Render(help["no-cache"]))
@@ -270,9 +271,12 @@ func initFlags() {
return roleNames(toComplete), cobra.ShellCompDirectiveDefault
})

+if config.FormatText == nil {
+config.FormatText = defaultConfig().FormatText
+}
+
if config.Format && config.FormatAs == "" {
config.FormatAs = "markdown"
-config.FormatText = defaultConfig().FormatText
}
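// NOTE: FormatText is now seeded from the defaults whenever it is nil, rather
// than only when --format-as is empty — the fix for --format/--format-as
// breaking for every value other than "markdown", per the commit message.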

if config.Format && config.FormatAs != "" && config.FormatText[config.FormatAs] == "" {
9 changes: 9 additions & 0 deletions mods.go
@@ -264,6 +264,7 @@ func (m *Mods) startCompletionCmd(content string) tea.Cmd {
var accfg AnthropicClientConfig
var cccfg CohereClientConfig
var occfg OllamaClientConfig
var gccfg GoogleClientConfig

cfg := m.Config
mod, ok = cfg.Models[cfg.Model]
@@ -326,6 +327,12 @@ func (m *Mods) startCompletionCmd(content string) tea.Cmd {
if api.Version != "" {
accfg.Version = AnthropicAPIVersion(api.Version)
}
case "google":
key, err := m.ensureKey(api, "GOOGLE_API_KEY", "https://aistudio.google.com/app/apikey")
if err != nil {
return err
}
gccfg = DefaultGoogleConfig(mod.Name, key)
case "cohere":
key, err := m.ensureKey(api, "COHERE_API_KEY", "https://dashboard.cohere.com/api-keys")
if err != nil {
@@ -374,6 +381,8 @@ func (m *Mods) startCompletionCmd(content string) tea.Cmd {
switch mod.API {
case "anthropic":
return m.createAnthropicStream(content, accfg, mod)
case "google":
return m.createGoogleStream(content, gccfg, mod)
case "cohere":
return m.createCohereStream(content, cccfg, mod)
case "ollama":