diff --git a/ai/ai21labs/v0/main.go b/ai/ai21labs/v0/main.go index 1a398810..a66f53e6 100644 --- a/ai/ai21labs/v0/main.go +++ b/ai/ai21labs/v0/main.go @@ -73,15 +73,15 @@ func (c *component) CreateExecution(sysVars map[string]any, setup *structpb.Stru } taskMap := map[string]func(*structpb.Struct) (*structpb.Struct, error){ - "TASK_TEXT_GENERATION_CHAT": e.TaskTextGenerationChat, - "TASK_TEXT_EMBEDDINGS": e.TaskTextEmbeddings, - "TASK_CONTEXTUAL_ANSWERING": e.TaskContextualAnswering, - "TASK_TEXT_SUMMARIZATION": e.TaskTextSummarization, - "TASK_TEXT_SUMMARIZATION_SEGMENT": e.TaskTextSummarizationBySegment, - "TASK_TEXT_PARAPHRASING": e.TaskTextParaphrasing, - "TASK_GRAMMAR_CHECK": e.TaskGrammarCheck, - "TASK_TEXT_IMPROVEMENT": e.TaskTextImprovement, - "TASK_TEXT_SEGMENTATION": e.TaskTextSegmentation, + TaskTextGenerationChat: e.TaskTextGenerationChat, + TaskTextEmbeddings: e.TaskTextEmbeddings, + TaskContextualAnswering: e.TaskContextualAnswering, + TaskTextSummarization: e.TaskTextSummarization, + TaskTextSummarizationBySegment: e.TaskTextSummarizationBySegment, + TaskTextParaphrasing: e.TaskTextParaphrasing, + TaskGrammarCheck: e.TaskGrammarCheck, + TaskTextImprovement: e.TaskTextImprovement, + TaskTextSegmentation: e.TaskTextSegmentation, } if taskFunc, ok := taskMap[task]; ok { e.execute = taskFunc diff --git a/ai/ai21labs/v0/tasks.go b/ai/ai21labs/v0/tasks.go index 61c38e43..620917cd 100644 --- a/ai/ai21labs/v0/tasks.go +++ b/ai/ai21labs/v0/tasks.go @@ -8,6 +8,18 @@ import ( // pricing info: https://www.ai21.com/pricing // note: task specific models are billed based on API calls, not generated tokens +const ( + TaskTextGenerationChat = "TASK_TEXT_GENERATION_CHAT" + TaskContextualAnswering = "TASK_CONTEXTUAL_ANSWERING" + TaskTextEmbeddings = "TASK_TEXT_EMBEDDINGS" + TaskTextImprovement = "TASK_TEXT_IMPROVEMENT" + TaskTextParaphrasing = "TASK_TEXT_PARAPHRASING" + TaskTextSummarization = "TASK_TEXT_SUMMARIZATION" + TaskTextSummarizationBySegment = "TASK_TEXT_SUMMARIZATION_SEGMENT" + TaskTextSegmentation = "TASK_TEXT_SEGMENTATION" + TaskGrammarCheck = "TASK_GRAMMAR_CHECK" +) + type TaskTextGenerationChatInput struct { base.TemplateTextGenerationInput TopP float64 `json:"top-p"` diff --git a/ai/ai21labs/v0/tasks_test.go b/ai/ai21labs/v0/tasks_test.go index 1e7e1874..0e292b9e 100644 --- a/ai/ai21labs/v0/tasks_test.go +++ b/ai/ai21labs/v0/tasks_test.go @@ -174,7 +174,7 @@ func TestTasks(t *testing.T) { }) c.Assert(err, qt.IsNil) e := &execution{ - ComponentExecution: base.ComponentExecution{Component: connector, SystemVariables: nil, Setup: setup, Task: "TASK_TEXT_GENERATION_CHAT"}, + ComponentExecution: base.ComponentExecution{Component: connector, SystemVariables: nil, Setup: setup, Task: TaskTextGenerationChat}, client: &MockAI21labsClient{}, usesInstillCredentials: false, } @@ -192,4 +192,316 @@ func TestTasks(t *testing.T) { c.Check(wantJSON, qt.JSONEquals, got[0].AsMap()) }) + c.Run("ok - task embedding", func(c *qt.C) { + tc := struct { + input map[string]any + wantResp TaskTextEmbeddingsOutput + }{ + input: map[string]any{"text": "Hello World!"}, + wantResp: TaskTextEmbeddingsOutput{ + Embedding: []float32{0.1, 0.2, 0.3}, + Usage: base.EmbeddingTextModelUsage{ + Tokens: len("Hello World!") / 2, // IMPORTANT: The vendor's API does not return the actual token count, so we are using a dummy value here. + }, + }, + } + setup, err := structpb.NewStruct(map[string]any{ + "api-key": apiKey, + }) + c.Assert(err, qt.IsNil) + e := &execution{ + ComponentExecution: base.ComponentExecution{Component: connector, SystemVariables: nil, Setup: setup, Task: TaskTextEmbeddings}, + client: &MockAI21labsClient{}, + usesInstillCredentials: false, + } + e.execute = e.TaskTextEmbeddings + exec := &base.ExecutionWrapper{Execution: e} + + pbIn, err := base.ConvertToStructpb(tc.input) + c.Assert(err, qt.IsNil) + + got, err := exec.Execution.Execute(ctx, []*structpb.Struct{pbIn}) + c.Assert(err, qt.IsNil) + + wantJSON, err := json.Marshal(tc.wantResp) + c.Assert(err, qt.IsNil) + c.Check(wantJSON, qt.JSONEquals, got[0].AsMap()) + }) + + c.Run("ok - task contextual answers without answers", func(c *qt.C) { + + tc := struct { + input map[string]any + wantResp TaskContextualAnsweringOutput + }{ + input: map[string]any{"question": ""}, + wantResp: TaskContextualAnsweringOutput{ + Answer: "Not found", + AnswerInContext: false, + }, + } + setup, err := structpb.NewStruct(map[string]any{ + "api-key": apiKey, + }) + c.Assert(err, qt.IsNil) + e := &execution{ + ComponentExecution: base.ComponentExecution{Component: connector, SystemVariables: nil, Setup: setup, Task: TaskContextualAnswering}, + client: &MockAI21labsClient{}, + usesInstillCredentials: false, + } + e.execute = e.TaskContextualAnswering + exec := &base.ExecutionWrapper{Execution: e} + + pbIn, err := base.ConvertToStructpb(tc.input) + c.Assert(err, qt.IsNil) + + got, err := exec.Execution.Execute(ctx, []*structpb.Struct{pbIn}) + c.Assert(err, qt.IsNil) + + wantJSON, err := json.Marshal(tc.wantResp) + c.Assert(err, qt.IsNil) + c.Check(wantJSON, qt.JSONEquals, got[0].AsMap()) + + }) + + c.Run("ok - task contextual answers with answers", func(c *qt.C) { + + tc := struct { + input map[string]any + wantResp TaskContextualAnsweringOutput + }{ + input: map[string]any{"question": "How's the weather today?"}, + wantResp: TaskContextualAnsweringOutput{ + Answer: "How's the weather today?" + " is a question", + AnswerInContext: true, + }, + } + setup, err := structpb.NewStruct(map[string]any{ + "api-key": apiKey, + }) + c.Assert(err, qt.IsNil) + e := &execution{ + ComponentExecution: base.ComponentExecution{Component: connector, SystemVariables: nil, Setup: setup, Task: TaskContextualAnswering}, + client: &MockAI21labsClient{}, + usesInstillCredentials: false, + } + e.execute = e.TaskContextualAnswering + exec := &base.ExecutionWrapper{Execution: e} + + pbIn, err := base.ConvertToStructpb(tc.input) + c.Assert(err, qt.IsNil) + + got, err := exec.Execution.Execute(ctx, []*structpb.Struct{pbIn}) + c.Assert(err, qt.IsNil) + + wantJSON, err := json.Marshal(tc.wantResp) + c.Assert(err, qt.IsNil) + c.Check(wantJSON, qt.JSONEquals, got[0].AsMap()) + }) + + c.Run("ok - task text summarization", func(c *qt.C) { + tc := struct { + input map[string]any + wantResp TaskTextSummarizationOutput + }{ + input: map[string]any{}, + wantResp: TaskTextSummarizationOutput{ + Summary: "ABC", + }, + } + setup, err := structpb.NewStruct(map[string]any{ + "api-key": apiKey, + }) + c.Assert(err, qt.IsNil) + e := &execution{ + ComponentExecution: base.ComponentExecution{Component: connector, SystemVariables: nil, Setup: setup, Task: TaskTextSummarization}, + client: &MockAI21labsClient{}, + usesInstillCredentials: false, + } + e.execute = e.TaskTextSummarization + exec := &base.ExecutionWrapper{Execution: e} + + pbIn, err := base.ConvertToStructpb(tc.input) + c.Assert(err, qt.IsNil) + + got, err := exec.Execution.Execute(ctx, []*structpb.Struct{pbIn}) + c.Assert(err, qt.IsNil) + + wantJSON, err := json.Marshal(tc.wantResp) + c.Assert(err, qt.IsNil) + c.Check(wantJSON, qt.JSONEquals, got[0].AsMap()) + }) + + c.Run("ok - task text summarization by segment", func(c *qt.C) { + tc := struct { + input map[string]any + wantResp TaskTextSummarizationBySegmentOutput + }{ + input: map[string]any{}, + wantResp: TaskTextSummarizationBySegmentOutput{ + Summerizations: []string{"abc"}, + SegmentTexts: []string{"ABC"}, + SegmentHtmls: []string{"