Skip to content

Commit

Permalink
fix: fix testcases
Browse files Browse the repository at this point in the history
  • Loading branch information
namwoam committed Jul 4, 2024
1 parent 8d35bdb commit d25291a
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions ai/cohere/v0/component_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func TestComponent_Generation(t *testing.T) {
wantResp textGenerationOutput
}{
input: map[string]any{"model-name": "command-r-plus"},
wantResp: textGenerationOutput{Text: "Hi! My name is command-r-plus.", Ciatations: []ciatation{}},
wantResp: textGenerationOutput{Text: "Hi! My name is command-r-plus.", Ciatations: []ciatation{}, Usage: commandUsage{InputTokens: 20, OutputTokens: 30}},
}

c.Run("ok - task command", func(c *qt.C) {
Expand Down Expand Up @@ -83,7 +83,7 @@ func TestComponent_Generation(t *testing.T) {
wantResp embeddingOutput
}{
input: map[string]any{"text": "abcde"},
wantResp: embeddingOutput{Embedding: []float64{0.1, 0.2, 0.3, 0.4, 0.5}},
wantResp: embeddingOutput{Embedding: []float64{0.1, 0.2, 0.3, 0.4, 0.5}, Usage: embedUsage{Tokens: 20}},
}

c.Run("ok - task embed", func(c *qt.C) {
Expand Down Expand Up @@ -114,7 +114,7 @@ func TestComponent_Generation(t *testing.T) {
wantResp rerankOutput
}{
input: map[string]any{"documents": []string{"a", "b", "c", "d"}},
wantResp: rerankOutput{Ranking: []string{"d", "c", "b", "a"}},
wantResp: rerankOutput{Ranking: []string{"d", "c", "b", "a"}, Usage: rerankUsage{Search: 5}},
}
c.Run("ok - task rerank", func(c *qt.C) {
setup, err := structpb.NewStruct(map[string]any{
Expand Down Expand Up @@ -146,15 +146,24 @@ type MockCohereClient struct{}
func (m *MockCohereClient) generateTextChat(request cohereSDK.ChatRequest) (cohereSDK.NonStreamedChatResponse, error) {
tx := fmt.Sprintf("Hi! My name is %s.", *request.Model)
cia := []*cohereSDK.ChatCitation{}
inputToken := float64(20)
outputToken := float64(30)
bill := cohereSDK.ApiMetaBilledUnits{InputTokens: &inputToken, OutputTokens: &outputToken}
meta := cohereSDK.ApiMeta{BilledUnits: &bill}
return cohereSDK.NonStreamedChatResponse{
Citations: cia,
Text: tx,
Meta: &meta,
}, nil
}

func (m *MockCohereClient) generateEmbedding(request cohereSDK.EmbedRequest) (cohereSDK.EmbedResponse, error) {
inputToken := float64(20)
bill := cohereSDK.ApiMetaBilledUnits{InputTokens: &inputToken}
meta := cohereSDK.ApiMeta{BilledUnits: &bill}
embedding := cohereSDK.EmbedFloatsResponse{
Embeddings: [][]float64{{0.1, 0.2, 0.3, 0.4, 0.5}},
Meta: &meta,
}
return cohereSDK.EmbedResponse{
EmbeddingsFloats: &embedding,
Expand All @@ -174,7 +183,11 @@ func (m *MockCohereClient) generateRerank(request cohereSDK.RerankRequest) (cohe
{Document: &documents[2]},
{Document: &documents[3]},
}
searchCnt := float64(5)
bill := cohereSDK.ApiMetaBilledUnits{SearchUnits: &searchCnt}
meta := cohereSDK.ApiMeta{BilledUnits: &bill}
return cohereSDK.RerankResponse{
Results: result,
Meta: &meta,
}, nil
}

0 comments on commit d25291a

Please sign in to comment.