Skip to content

Commit

Permalink
Merge branch 'main' into namwoam/mistral
Browse files Browse the repository at this point in the history
  • Loading branch information
donch1989 authored Jul 15, 2024
2 parents 31ebe22 + cba4aac commit d850b8b
Show file tree
Hide file tree
Showing 43 changed files with 930 additions and 476 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ repos:
- id: check-json
- id: check-merge-conflict
- id: end-of-file-fixer
exclude: tools/compogen/cmd/testdata
exclude: (?i).*testdata/
exclude_types: [svg,mdx]
- id: trailing-whitespace
- id: pretty-format-json
Expand Down
16 changes: 8 additions & 8 deletions ai/cohere/v0/component_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@ func TestComponent_Tasks(t *testing.T) {

commandTc := struct {
input map[string]any
wantResp textGenerationOutput
wantResp TextGenerationOutput
}{
input: map[string]any{"model-name": "command-r-plus"},
wantResp: textGenerationOutput{Text: "Hi! My name is command-r-plus.", Citations: []citation{}, Usage: commandUsage{InputTokens: 20, OutputTokens: 30}},
wantResp: TextGenerationOutput{Text: "Hi! My name is command-r-plus.", Citations: []citation{}, Usage: commandUsage{InputTokens: 20, OutputTokens: 30}},
}

c.Run("ok - task command", func(c *qt.C) {
Expand Down Expand Up @@ -92,10 +92,10 @@ func TestComponent_Tasks(t *testing.T) {

embedFloatTc := struct {
input map[string]any
wantResp embeddingFloatOutput
wantResp EmbeddingFloatOutput
}{
input: map[string]any{"text": "abcde"},
wantResp: embeddingFloatOutput{Embedding: []float64{0.1, 0.2, 0.3, 0.4, 0.5}, Usage: embedUsage{Tokens: 20}},
wantResp: EmbeddingFloatOutput{Embedding: []float64{0.1, 0.2, 0.3, 0.4, 0.5}, Usage: embedUsage{Tokens: 20}},
}

c.Run("ok - task float embed", func(c *qt.C) {
Expand Down Expand Up @@ -123,10 +123,10 @@ func TestComponent_Tasks(t *testing.T) {

embedIntTc := struct {
input map[string]any
wantResp embeddingIntOutput
wantResp EmbeddingIntOutput
}{
input: map[string]any{"text": "abcde", "embedding-type": "int8"},
wantResp: embeddingIntOutput{Embedding: []int{1, 2, 3, 4, 5}, Usage: embedUsage{Tokens: 20}},
wantResp: EmbeddingIntOutput{Embedding: []int{1, 2, 3, 4, 5}, Usage: embedUsage{Tokens: 20}},
}

c.Run("ok - task int embed", func(c *qt.C) {
Expand Down Expand Up @@ -154,10 +154,10 @@ func TestComponent_Tasks(t *testing.T) {

rerankTc := struct {
input map[string]any
wantResp rerankOutput
wantResp RerankOutput
}{
input: map[string]any{"documents": []string{"a", "b", "c", "d"}},
wantResp: rerankOutput{Ranking: []string{"d", "c", "b", "a"}, Usage: rerankUsage{Search: 5}, Relevance: []float64{10, 9, 8, 7}},
wantResp: RerankOutput{Ranking: []string{"d", "c", "b", "a"}, Usage: rerankUsage{Search: 5}, Relevance: []float64{10, 9, 8, 7}},
}
c.Run("ok - task rerank", func(c *qt.C) {
setup, err := structpb.NewStruct(map[string]any{
Expand Down
31 changes: 31 additions & 0 deletions ai/cohere/v0/config/tasks.json
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,17 @@
"instillAcceptFormats": [
"string"
],
"instillCredentialMap": {
"values": [
"command-r-plus",
"command-r",
"command",
"command-light"
],
"targets": [
"setup.api-key"
]
},
"instillUIOrder": 0,
"instillUpstreamTypes": [
"value",
Expand Down Expand Up @@ -436,6 +447,17 @@
"instillAcceptFormats": [
"string"
],
"instillCredentialMap": {
"values": [
"embed-english-v3.0",
"embed-multilingual-v3.0",
"embed-english-light-v3.0",
"embed-multilingual-light-v3.0"
],
"targets": [
"setup.api-key"
]
},
"instillUIOrder": 0,
"instillUpstreamTypes": [
"value",
Expand Down Expand Up @@ -561,6 +583,15 @@
"instillAcceptFormats": [
"string"
],
"instillCredentialMap": {
"values": [
"rerank-english-v3.0",
"rerank-multilingual-v3.0"
],
"targets": [
"setup.api-key"
]
},
"instillUIOrder": 0,
"instillUpstreamTypes": [
"value",
Expand Down
166 changes: 102 additions & 64 deletions ai/cohere/v0/embedding.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,19 @@ import (
"google.golang.org/protobuf/types/known/structpb"
)

type embeddingInput struct {
type EmbeddingInput struct {
Text string `json:"text"`
ModelName string `json:"model-name"`
InputType string `json:"input-type"`
EmbeddingType string `json:"embedding-type"`
}

type embeddingFloatOutput struct {
type EmbeddingFloatOutput struct {
Usage embedUsage `json:"usage"`
Embedding []float64 `json:"embedding"`
}

type embeddingIntOutput struct {
type EmbeddingIntOutput struct {
Usage embedUsage `json:"usage"`
Embedding []int `json:"embedding"`
}
Expand All @@ -30,24 +30,77 @@ type embedUsage struct {
}

func (e *execution) taskEmbedding(in *structpb.Struct) (*structpb.Struct, error) {
inputStruct := embeddingInput{}
inputStruct := EmbeddingInput{}
err := base.ConvertFromStructpb(in, &inputStruct)
if err != nil {
return nil, fmt.Errorf("error generating input struct: %v", err)
}

if IsEmbeddingOutputInt(inputStruct.EmbeddingType) {
tokenCount, embedding, err := processWithIntOutput(e, inputStruct)
if err != nil {
return nil, err
}

outputStruct := EmbeddingIntOutput{
Usage: embedUsage{
Tokens: tokenCount,
},
Embedding: embedding,
}
output, err := base.ConvertToStructpb(outputStruct)
if err != nil {
return nil, err
}
return output, nil
}

tokenCount, embedding, err := processWithFloatOutput(e, inputStruct)
if err != nil {
return nil, err
}
outputStruct := EmbeddingFloatOutput{
Usage: embedUsage{
Tokens: tokenCount,
},
Embedding: embedding,
}
output, err := base.ConvertToStructpb(outputStruct)
if err != nil {
return nil, err
}
return output, nil

}

func IsEmbeddingOutputInt(embeddingType string) bool {
return embeddingType == "int8" || embeddingType == "uint8" || embeddingType == "binary" || embeddingType == "ubinary"
}

func processWithIntOutput(e *execution, inputStruct EmbeddingInput) (tokenCount int, embedding []int, err error) {
req := cohereSDK.EmbedRequest{
Texts: []string{inputStruct.Text},
Model: &inputStruct.ModelName,
InputType: (*cohereSDK.EmbedInputType)(&inputStruct.InputType),
EmbeddingTypes: []cohereSDK.EmbeddingType{cohereSDK.EmbeddingType(inputStruct.EmbeddingType)},
}
resp, err := e.client.generateEmbedding(req)

if err != nil {
return 0, nil, err
}

embeddingResult, err := getIntEmbedding(resp, inputStruct.EmbeddingType)
if err != nil {
return 0, nil, err
}
return getBillingTokens(resp, inputStruct.EmbeddingType), embeddingResult, nil
}

func processWithFloatOutput(e *execution, inputStruct EmbeddingInput) (tokenCount int, embedding []float64, err error) {
embeddingTypeArray := []cohereSDK.EmbeddingType{}
switch inputStruct.EmbeddingType {
case "float":
if inputStruct.EmbeddingType == "float" {
embeddingTypeArray = append(embeddingTypeArray, cohereSDK.EmbeddingTypeFloat)
case "int8":
embeddingTypeArray = append(embeddingTypeArray, cohereSDK.EmbeddingTypeInt8)
case "uint8":
embeddingTypeArray = append(embeddingTypeArray, cohereSDK.EmbeddingTypeUint8)
case "binary":
embeddingTypeArray = append(embeddingTypeArray, cohereSDK.EmbeddingTypeBinary)
case "ubinary":
embeddingTypeArray = append(embeddingTypeArray, cohereSDK.EmbeddingTypeUbinary)
}
req := cohereSDK.EmbedRequest{
Texts: []string{inputStruct.Text},
Expand All @@ -56,58 +109,43 @@ func (e *execution) taskEmbedding(in *structpb.Struct) (*structpb.Struct, error)
EmbeddingTypes: embeddingTypeArray,
}
resp, err := e.client.generateEmbedding(req)

if err != nil {
return nil, err
return 0, nil, err
}

switch inputStruct.EmbeddingType {
case "int8", "uint8", "binary", "ubinary":
bills := resp.EmbeddingsByType.Meta.BilledUnits
outputStruct := embeddingIntOutput{
Usage: embedUsage{
Tokens: int(*bills.InputTokens),
},
}
switch inputStruct.EmbeddingType {
case "int8":
outputStruct.Embedding = resp.EmbeddingsByType.Embeddings.Int8[0]
case "uint8":
outputStruct.Embedding = resp.EmbeddingsByType.Embeddings.Uint8[0]
case "binary":
outputStruct.Embedding = resp.EmbeddingsByType.Embeddings.Binary[0]
case "ubinary":
outputStruct.Embedding = resp.EmbeddingsByType.Embeddings.Ubinary[0]
}
output, err := base.ConvertToStructpb(outputStruct)
if err != nil {
return nil, err
}
return output, nil
case "float":
bills := resp.EmbeddingsByType.Meta.BilledUnits
outputStruct := embeddingFloatOutput{
Usage: embedUsage{
Tokens: int(*bills.InputTokens),
},
Embedding: resp.EmbeddingsByType.Embeddings.Float[0],
}
output, err := base.ConvertToStructpb(outputStruct)
if err != nil {
return nil, err
}
return output, nil
default:
bills := resp.EmbeddingsFloats.Meta.BilledUnits
outputStruct := embeddingFloatOutput{
Usage: embedUsage{
Tokens: int(*bills.InputTokens),
},
Embedding: resp.EmbeddingsFloats.Embeddings[0],
}
output, err := base.ConvertToStructpb(outputStruct)
if err != nil {
return nil, err
}
return output, nil
embeddingResult := getFloatEmbedding(resp, inputStruct.EmbeddingType)

return getBillingTokens(resp, inputStruct.EmbeddingType), embeddingResult, nil

}

func getIntEmbedding(resp cohereSDK.EmbedResponse, embeddingType string) ([]int, error) {
switch embeddingType {
case "int8":
return resp.EmbeddingsByType.Embeddings.Int8[0], nil
case "uint8":
return resp.EmbeddingsByType.Embeddings.Uint8[0], nil
case "binary":
return resp.EmbeddingsByType.Embeddings.Binary[0], nil
case "ubinary":
return resp.EmbeddingsByType.Embeddings.Ubinary[0], nil
}
return nil, fmt.Errorf("invalid embedding type: %s", embeddingType)
}

func getFloatEmbedding(resp cohereSDK.EmbedResponse, embeddingType string) []float64 {
if embeddingType == "float" {
return resp.EmbeddingsByType.Embeddings.Float[0]
} else {
return resp.EmbeddingsFloats.Embeddings[0]
}
}

func getBillingTokens(resp cohereSDK.EmbedResponse, embeddingType string) int {
if IsEmbeddingOutputInt(embeddingType) || embeddingType == "float" {
return int(*resp.EmbeddingsByType.Meta.BilledUnits.InputTokens)
} else {
return int(*resp.EmbeddingsFloats.Meta.BilledUnits.InputTokens)
}
}
2 changes: 1 addition & 1 deletion ai/cohere/v0/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func (c *component) CreateExecution(sysVars map[string]any, setup *structpb.Stru
}
e := &execution{
ComponentExecution: base.ComponentExecution{Component: c, SystemVariables: sysVars, Task: task, Setup: resolvedSetup},
client: newClient(getAPIKey(setup), c.GetLogger()),
client: newClient(getAPIKey(resolvedSetup), c.GetLogger()),
usesInstillCredentials: resolved,
}
switch task {
Expand Down
8 changes: 4 additions & 4 deletions ai/cohere/v0/rerank.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ import (
"google.golang.org/protobuf/types/known/structpb"
)

type rerankInput struct {
type RerankInput struct {
Query string `json:"query"`
Documents []string `json:"documents"`
ModelName string `json:"model-name"`
}

type rerankOutput struct {
type RerankOutput struct {
Ranking []string `json:"ranking"`
Usage rerankUsage `json:"usage"`
Relevance []float64 `json:"relevance"`
Expand All @@ -26,7 +26,7 @@ type rerankUsage struct {

func (e *execution) taskRerank(in *structpb.Struct) (*structpb.Struct, error) {

inputStruct := rerankInput{}
inputStruct := RerankInput{}
err := base.ConvertFromStructpb(in, &inputStruct)
if err != nil {
return nil, fmt.Errorf("error generating input struct: %v", err)
Expand Down Expand Up @@ -61,7 +61,7 @@ func (e *execution) taskRerank(in *structpb.Struct) (*structpb.Struct, error) {
}
bills := resp.Meta.BilledUnits

outputStruct := rerankOutput{
outputStruct := RerankOutput{
Ranking: newRanking,
Usage: rerankUsage{Search: int(*bills.SearchUnits)},
Relevance: relevance,
Expand Down
8 changes: 4 additions & 4 deletions ai/cohere/v0/text_generation.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ type MultiModalContent struct {
Type string `json:"type"`
}

type textGenerationInput struct {
type TextGenerationInput struct {
ChatHistory []ChatMessage `json:"chat-history"`
MaxNewTokens int `json:"max-new-tokens"`
ModelName string `json:"model-name"`
Expand All @@ -46,15 +46,15 @@ type commandUsage struct {
OutputTokens int `json:"output-tokens"`
}

type textGenerationOutput struct {
type TextGenerationOutput struct {
Text string `json:"text"`
Citations []citation `json:"citations"`
Usage commandUsage `json:"usage"`
}

func (e *execution) taskTextGeneration(in *structpb.Struct) (*structpb.Struct, error) {

inputStruct := textGenerationInput{}
inputStruct := TextGenerationInput{}
err := base.ConvertFromStructpb(in, &inputStruct)
if err != nil {
return nil, fmt.Errorf("error generating input struct: %v", err)
Expand Down Expand Up @@ -123,7 +123,7 @@ func (e *execution) taskTextGeneration(in *structpb.Struct) (*structpb.Struct, e
inputTokens := *bills.InputTokens
outputTokens := *bills.OutputTokens

outputStruct := textGenerationOutput{
outputStruct := TextGenerationOutput{
Text: resp.Text,
Citations: citations,
Usage: commandUsage{
Expand Down
Loading

0 comments on commit d850b8b

Please sign in to comment.