diff --git a/ai/openai/v0/main.go b/ai/openai/v0/main.go
index c4152338..48740b11 100644
--- a/ai/openai/v0/main.go
+++ b/ai/openai/v0/main.go
@@ -379,51 +379,6 @@ func (e *execution) worker(ctx context.Context, client *httpclient.Client, job *
 		}
 	}
 
-	case TextEmbeddingsTask:
-		inputStruct := TextEmbeddingsInput{}
-		err := base.ConvertFromStructpb(input, &inputStruct)
-		if err != nil {
-			job.Error.Error(ctx, err)
-			return
-		}
-
-		resp := TextEmbeddingsResp{}
-
-		var reqParams TextEmbeddingsReq
-		if inputStruct.Dimensions == 0 {
-			reqParams = TextEmbeddingsReq{
-				Model: inputStruct.Model,
-				Input: []string{inputStruct.Text},
-			}
-		} else {
-			reqParams = TextEmbeddingsReq{
-				Model:      inputStruct.Model,
-				Input:      []string{inputStruct.Text},
-				Dimensions: inputStruct.Dimensions,
-			}
-		}
-
-		req := client.R().SetBody(reqParams).SetResult(&resp)
-		if _, err := req.Post(embeddingsPath); err != nil {
-			job.Error.Error(ctx, err)
-			return
-		}
-
-		outputStruct := TextEmbeddingsOutput{
-			Embedding: resp.Data[0].Embedding,
-		}
-
-		output, err := base.ConvertToStructpb(outputStruct)
-		if err != nil {
-			job.Error.Error(ctx, err)
-			return
-		}
-		err = job.Output.Write(ctx, output)
-		if err != nil {
-			job.Error.Error(ctx, err)
-			return
-		}
-
 	case SpeechRecognitionTask:
 		inputStruct := AudioTranscriptionInput{}
 		err := base.ConvertFromStructpb(input, &inputStruct)
@@ -573,25 +528,105 @@ func chunk(items []*base.Job, batchSize int) (chunks [][]*base.Job) {
 	return append(chunks, items)
 }
 
+// executeEmbedding batches the inputs of all jobs into a single embeddings request.
+func (e *execution) executeEmbedding(ctx context.Context, client *httpclient.Client, jobs []*base.Job) {
+
+	texts := make([]string, len(jobs))
+	dimensions := 0
+	model := ""
+	for idx, job := range jobs {
+		input, err := job.Input.Read(ctx)
+		if err != nil {
+			job.Error.Error(ctx, err)
+			return
+		}
+		inputStruct := TextEmbeddingsInput{}
+		err = base.ConvertFromStructpb(input, &inputStruct)
+		if err != nil {
+			job.Error.Error(ctx, err)
+			return
+		}
+		texts[idx] = inputStruct.Text
+
+		if idx == 0 {
+			// Note: currently we assume that all jobs in the batch use the
+			// same model and dimension settings. We still need a check for
+			// this: if the model or dimensions differ, the jobs should be
+			// split into separate inference groups.
+			dimensions = inputStruct.Dimensions
+			model = inputStruct.Model
+		}
+	}
+
+	resp := TextEmbeddingsResp{}
+
+	var reqParams TextEmbeddingsReq
+	if dimensions == 0 {
+		reqParams = TextEmbeddingsReq{
+			Model: model,
+			Input: texts,
+		}
+	} else {
+		reqParams = TextEmbeddingsReq{
+			Model:      model,
+			Input:      texts,
+			Dimensions: dimensions,
+		}
+	}
+
+	req := client.R().SetBody(reqParams).SetResult(&resp)
+	if _, err := req.Post(embeddingsPath); err != nil {
+		for _, job := range jobs {
+			job.Error.Error(ctx, err)
+		}
+		return
+	}
+
+	for idx, job := range jobs {
+		outputStruct := TextEmbeddingsOutput{
+			Embedding: resp.Data[idx].Embedding,
+		}
+		output, err := base.ConvertToStructpb(outputStruct)
+		if err != nil {
+			job.Error.Error(ctx, err)
+			return
+		}
+		err = job.Output.Write(ctx, output)
+		if err != nil {
+			job.Error.Error(ctx, err)
+			return
+		}
+	}
+}
+
 func (e *execution) Execute(ctx context.Context, jobs []*base.Job) error {
 	client := newClient(e.Setup, e.GetLogger())
 	client.SetRetryCount(retryCount)
 	client.SetRetryWaitTime(1 * time.Second)
 
-	// TODO: we can encapsulate this code into a `ConcurrentExecutor`.
-	// The `ConcurrentExecutor` will use goroutines to execute jobs in parallel.
-	batchSize := 4
-	for _, batch := range chunk(jobs, batchSize) {
-		var wg sync.WaitGroup
-		wg.Add(len(batch))
-		for _, job := range batch {
-			go func() {
-				defer wg.Done()
-				e.worker(ctx, client, job)
-			}()
-		}
-		wg.Wait()
-	}
+	switch e.Task {
+	case TextEmbeddingsTask:
+		// The OpenAI embeddings API supports batch inference, so we send the
+		// inputs of all jobs in a single request instead of one request per job.
+		e.executeEmbedding(ctx, client, jobs)
+
+	default:
+		// TODO: we can encapsulate this code into a `ConcurrentExecutor`.
+		// The `ConcurrentExecutor` will use goroutines to execute jobs in parallel.
+		batchSize := 4
+		for _, batch := range chunk(jobs, batchSize) {
+			var wg sync.WaitGroup
+			wg.Add(len(batch))
+			for _, job := range batch {
+				go func() {
+					defer wg.Done()
+					e.worker(ctx, client, job)
+				}()
+			}
+			wg.Wait()
+		}
+	}
 	return nil
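
The note inside executeEmbedding assumes every job in the batch shares one model and dimension setting. One way to lift that assumption is to partition jobs into inference groups before batching. Below is a minimal sketch of that idea; groupKey and groupJobs are hypothetical helpers, not part of this diff, and they reuse base.Job, TextEmbeddingsInput, and base.ConvertFromStructpb from the surrounding package.

// groupKey identifies an inference group: jobs that share a model and a
// dimension setting can be embedded in one batched request.
type groupKey struct {
	model      string
	dimensions int
}

// groupJobs partitions jobs by (model, dimensions). Each resulting group can
// then be handed to a single batched embeddings call such as executeEmbedding.
func groupJobs(ctx context.Context, jobs []*base.Job) (map[groupKey][]*base.Job, error) {
	groups := map[groupKey][]*base.Job{}
	for _, job := range jobs {
		input, err := job.Input.Read(ctx)
		if err != nil {
			return nil, err
		}
		inputStruct := TextEmbeddingsInput{}
		if err := base.ConvertFromStructpb(input, &inputStruct); err != nil {
			return nil, err
		}
		key := groupKey{model: inputStruct.Model, dimensions: inputStruct.Dimensions}
		groups[key] = append(groups[key], job)
	}
	return groups, nil
}

A caller would loop over the returned map and run the batched path once per group; to avoid reading each job's input twice, groupJobs could also return the decoded inputs alongside the grouped jobs.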
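
Similarly, the ConcurrentExecutor TODO in the default branch could be factored out roughly as follows. This is only a sketch that preserves the current semantics (batches of four, each batch running to completion before the next starts); concurrentExecutor is a hypothetical name.

// concurrentExecutor runs work over jobs in fixed-size batches, waiting for
// each batch to finish before starting the next one.
func concurrentExecutor(ctx context.Context, jobs []*base.Job, batchSize int, work func(context.Context, *base.Job)) {
	for _, batch := range chunk(jobs, batchSize) {
		var wg sync.WaitGroup
		wg.Add(len(batch))
		for _, job := range batch {
			go func(job *base.Job) {
				defer wg.Done()
				work(ctx, job)
			}(job)
		}
		wg.Wait()
	}
}

The default branch would then reduce to concurrentExecutor(ctx, jobs, 4, func(ctx context.Context, job *base.Job) { e.worker(ctx, client, job) }). Passing job to the goroutine explicitly also keeps the code correct on Go versions before 1.22, where the loop variable is shared across iterations.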