This repository has been archived by the owner on Oct 29, 2024. It is now read-only.

feat(openai): use batch inference for embedding #375

Merged 1 commit on Sep 27, 2024
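
This PR replaces the per-job embedding request in the worker with a single batched call: the OpenAI embeddings endpoint accepts an array of inputs, so every text in a batch can be embedded in one round trip. A minimal sketch of a batched request using the TextEmbeddingsReq type from this diff (the model name and texts are only placeholders):

	reqParams := TextEmbeddingsReq{
		Model: "text-embedding-3-small", // placeholder model name
		Input: []string{"first text", "second text", "third text"}, // whole batch in one request
	}

The code below relies on resp.Data coming back in the same order as Input, which is what lets executeEmbedding map resp.Data[idx] back to jobs[idx].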
151 changes: 93 additions & 58 deletions ai/openai/v0/main.go
@@ -379,51 +379,6 @@ func (e *execution) worker(ctx context.Context, client *httpclient.Client, job *
}
}

case TextEmbeddingsTask:
inputStruct := TextEmbeddingsInput{}
err := base.ConvertFromStructpb(input, &inputStruct)
if err != nil {
job.Error.Error(ctx, err)
return
}

resp := TextEmbeddingsResp{}

var reqParams TextEmbeddingsReq
if inputStruct.Dimensions == 0 {
reqParams = TextEmbeddingsReq{
Model: inputStruct.Model,
Input: []string{inputStruct.Text},
}
} else {
reqParams = TextEmbeddingsReq{
Model: inputStruct.Model,
Input: []string{inputStruct.Text},
Dimensions: inputStruct.Dimensions,
}
}

req := client.R().SetBody(reqParams).SetResult(&resp)
if _, err := req.Post(embeddingsPath); err != nil {
job.Error.Error(ctx, err)
return
}

outputStruct := TextEmbeddingsOutput{
Embedding: resp.Data[0].Embedding,
}

output, err := base.ConvertToStructpb(outputStruct)
if err != nil {
job.Error.Error(ctx, err)
return
}
err = job.Output.Write(ctx, output)
if err != nil {
job.Error.Error(ctx, err)
return
}

case SpeechRecognitionTask:
inputStruct := AudioTranscriptionInput{}
err := base.ConvertFromStructpb(input, &inputStruct)
@@ -573,25 +528,105 @@ func chunk(items []*base.Job, batchSize int) (chunks [][]*base.Job) {
return append(chunks, items)
}

func (e *execution) executeEmbedding(ctx context.Context, client *httpclient.Client, jobs []*base.Job) {

texts := make([]string, len(jobs))
dimensions := 0
model := ""
for idx, job := range jobs {
input, err := job.Input.Read(ctx)
if err != nil {
job.Error.Error(ctx, err)
return
}
inputStruct := TextEmbeddingsInput{}
err = base.ConvertFromStructpb(input, &inputStruct)
if err != nil {
job.Error.Error(ctx, err)
return
}
texts[idx] = inputStruct.Text

if idx == 0 {
// Note: Currently, we assume that all data in the batch uses the same
// model and dimension settings. We need to add a check for this:
// if the model or dimensions differ, we should separate them into
// different inference groups.
dimensions = inputStruct.Dimensions
model = inputStruct.Model
}
}

resp := TextEmbeddingsResp{}

var reqParams TextEmbeddingsReq
if dimensions == 0 {
reqParams = TextEmbeddingsReq{
Model: model,
Input: texts,
}
} else {
reqParams = TextEmbeddingsReq{
Model: model,
Input: texts,
Dimensions: dimensions,
}
}

req := client.R().SetBody(reqParams).SetResult(&resp)
if _, err := req.Post(embeddingsPath); err != nil {
for _, job := range jobs {
job.Error.Error(ctx, err)
}
return
}

for idx, job := range jobs {
outputStruct := TextEmbeddingsOutput{
Embedding: resp.Data[idx].Embedding,
}
output, err := base.ConvertToStructpb(outputStruct)
if err != nil {
job.Error.Error(ctx, err)
return
}
err = job.Output.Write(ctx, output)
if err != nil {
job.Error.Error(ctx, err)
return
}
}
}
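
The note inside executeEmbedding assumes every job in a batch shares the same model and dimensions. One possible follow-up, sketched here purely as an illustration (the groupKey type and groupJobsByParams helper are hypothetical, not part of this PR; the file's existing imports are assumed), is to bucket jobs by those settings and run one batched request per bucket:

	// Hypothetical sketch: split jobs into inference groups keyed by (model, dimensions),
	// so that each group can be sent as its own batched embeddings request.
	type groupKey struct {
		model      string
		dimensions int
	}

	func groupJobsByParams(ctx context.Context, jobs []*base.Job) map[groupKey][]*base.Job {
		groups := map[groupKey][]*base.Job{}
		for _, job := range jobs {
			input, err := job.Input.Read(ctx)
			if err != nil {
				job.Error.Error(ctx, err)
				continue
			}
			inputStruct := TextEmbeddingsInput{}
			if err := base.ConvertFromStructpb(input, &inputStruct); err != nil {
				job.Error.Error(ctx, err)
				continue
			}
			key := groupKey{model: inputStruct.Model, dimensions: inputStruct.Dimensions}
			groups[key] = append(groups[key], job)
		}
		return groups
	}

Each resulting slice could then be passed through the same request-building and output-writing logic as executeEmbedding.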

func (e *execution) Execute(ctx context.Context, jobs []*base.Job) error {

client := newClient(e.Setup, e.GetLogger())
client.SetRetryCount(retryCount)
client.SetRetryWaitTime(1 * time.Second)

// TODO: we can encapsulate this code into a `ConcurrentExecutor`.
// The `ConcurrentExecutor` will use goroutines to execute jobs in parallel.
batchSize := 4
for _, batch := range chunk(jobs, batchSize) {
var wg sync.WaitGroup
wg.Add(len(batch))
for _, job := range batch {
go func() {
defer wg.Done()
e.worker(ctx, client, job)
}()
}
wg.Wait()
switch e.Task {
case TextEmbeddingsTask:
// OpenAI embedding API supports batch inference, so we'll leverage it
// directly for optimal performance.
e.executeEmbedding(ctx, client, jobs)

default:
// TODO: we can encapsulate this code into a `ConcurrentExecutor`.
// The `ConcurrentExecutor` will use goroutines to execute jobs in parallel.
batchSize := 4
for _, batch := range chunk(jobs, batchSize) {
var wg sync.WaitGroup
wg.Add(len(batch))
for _, job := range batch {
go func() {
defer wg.Done()
e.worker(ctx, client, job)
}()
}
wg.Wait()
}

}

return nil
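
The TODO above mentions encapsulating the goroutine fan-out of the default branch into a ConcurrentExecutor. A rough illustration of what such a helper could look like (runConcurrently is a hypothetical name, not an existing API in this repository; it reuses the chunk helper from this file):

	// Hypothetical sketch of the ConcurrentExecutor mentioned in the TODO: run a
	// worker function over jobs in fixed-size batches, one goroutine per job.
	func runConcurrently(ctx context.Context, jobs []*base.Job, batchSize int, worker func(context.Context, *base.Job)) {
		for _, batch := range chunk(jobs, batchSize) {
			var wg sync.WaitGroup
			wg.Add(len(batch))
			for _, job := range batch {
				job := job // capture the loop variable for the goroutine (needed before Go 1.22)
				go func() {
					defer wg.Done()
					worker(ctx, job)
				}()
			}
			wg.Wait()
		}
	}

The default branch of Execute could then reduce to a single call such as runConcurrently(ctx, jobs, batchSize, func(ctx context.Context, job *base.Job) { e.worker(ctx, client, job) }).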