This repository has been archived by the owner on Oct 29, 2024. It is now read-only.

Commit

feat(openai): use batch inference for embedding
donch1989 committed Sep 27, 2024
1 parent 162ddad commit 63ad8ae
Showing 1 changed file with 93 additions and 58 deletions.
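For context on the change: the OpenAI embeddings endpoint accepts "input" as either a single string or an array of strings, which is what lets this commit replace one request per job with one request per batch of jobs. A minimal sketch of the batched request body this enables; the struct name and JSON tags here are assumptions modeled on the OpenAI API, not copied from the repository:

package main

import (
    "encoding/json"
    "fmt"
)

// embeddingsReq mirrors the component's TextEmbeddingsReq; the JSON tags
// are assumptions based on the OpenAI embeddings API.
type embeddingsReq struct {
    Model      string   `json:"model"`
    Input      []string `json:"input"`
    Dimensions int      `json:"dimensions,omitempty"`
}

func main() {
    // One request now carries every text in the batch.
    req := embeddingsReq{
        Model: "text-embedding-3-small",
        Input: []string{"first text", "second text", "third text"},
    }
    b, _ := json.MarshalIndent(req, "", "  ")
    fmt.Println(string(b))
}

Note the omitempty on Dimensions: with it, a zero value is dropped from the payload automatically; the diff below achieves the same effect explicitly by branching on dimensions == 0.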
151 changes: 93 additions & 58 deletions ai/openai/v0/main.go
@@ -379,51 +379,6 @@ func (e *execution) worker(ctx context.Context, client *httpclient.Client, job *base.Job) {
}
}

case TextEmbeddingsTask:
inputStruct := TextEmbeddingsInput{}
err := base.ConvertFromStructpb(input, &inputStruct)
if err != nil {
job.Error.Error(ctx, err)
return
}

resp := TextEmbeddingsResp{}

var reqParams TextEmbeddingsReq
if inputStruct.Dimensions == 0 {
reqParams = TextEmbeddingsReq{
Model: inputStruct.Model,
Input: []string{inputStruct.Text},
}
} else {
reqParams = TextEmbeddingsReq{
Model: inputStruct.Model,
Input: []string{inputStruct.Text},
Dimensions: inputStruct.Dimensions,
}
}

req := client.R().SetBody(reqParams).SetResult(&resp)
if _, err := req.Post(embeddingsPath); err != nil {
job.Error.Error(ctx, err)
return
}

outputStruct := TextEmbeddingsOutput{
Embedding: resp.Data[0].Embedding,
}

output, err := base.ConvertToStructpb(outputStruct)
if err != nil {
job.Error.Error(ctx, err)
return
}
err = job.Output.Write(ctx, output)
if err != nil {
job.Error.Error(ctx, err)
return
}

case SpeechRecognitionTask:
inputStruct := AudioTranscriptionInput{}
err := base.ConvertFromStructpb(input, &inputStruct)
@@ -573,25 +528,105 @@ func chunk(items []*base.Job, batchSize int) (chunks [][]*base.Job) {
return append(chunks, items)
}
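// Editor's note: chunk splits jobs into fixed-size batches with a shorter
// final batch; e.g. 10 jobs and batchSize 4 yield batches of 4, 4, and 2.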

func (e *execution) executeEmbedding(ctx context.Context, client *httpclient.Client, jobs []*base.Job) {

texts := make([]string, len(jobs))
dimensions := 0
model := ""
for idx, job := range jobs {
input, err := job.Input.Read(ctx)
if err != nil {
job.Error.Error(ctx, err)
return
}
inputStruct := TextEmbeddingsInput{}
err = base.ConvertFromStructpb(input, &inputStruct)
if err != nil {
job.Error.Error(ctx, err)
return
}
texts[idx] = inputStruct.Text

if idx == 0 {
// Note: currently we assume that all data in the batch uses the same
// model and dimension settings. We still need a check for this: if the
// model or dimensions differ, the jobs should be split into separate
// inference groups (a sketch of that grouping follows this function).
dimensions = inputStruct.Dimensions
model = inputStruct.Model
}
}

resp := TextEmbeddingsResp{}

var reqParams TextEmbeddingsReq
if dimensions == 0 {
reqParams = TextEmbeddingsReq{
Model: model,
Input: texts,
}
} else {
reqParams = TextEmbeddingsReq{
Model: model,
Input: texts,
Dimensions: dimensions,
}
}

req := client.R().SetBody(reqParams).SetResult(&resp)
if _, err := req.Post(embeddingsPath); err != nil {
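// A single request serves the whole batch, so a transport failure is
// surfaced to every job in it.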
for _, job := range jobs {
job.Error.Error(ctx, err)
}
return
}

for idx, job := range jobs {
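// The embeddings endpoint returns one data object per input, in input
// order, so resp.Data[idx] lines up with jobs[idx].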
outputStruct := TextEmbeddingsOutput{
Embedding: resp.Data[idx].Embedding,
}
output, err := base.ConvertToStructpb(outputStruct)
if err != nil {
job.Error.Error(ctx, err)
return
}
err = job.Output.Write(ctx, output)
if err != nil {
job.Error.Error(ctx, err)
return
}
}
}
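
// ---------------------------------------------------------------------
// Editor's sketch (not part of this commit): the note in executeEmbedding
// assumes one model and dimension setting per batch. Splitting mixed
// batches into inference groups could look like the following; groupKey
// and groupJobs are hypothetical names.
// ---------------------------------------------------------------------

// groupKey identifies jobs that can share one batched embeddings request.
type groupKey struct {
    model      string
    dimensions int
}

// groupJobs buckets job indices by (model, dimensions); each bucket can
// then be sent as its own batched request, and the stored indices let the
// results be written back to the right jobs afterwards.
func groupJobs(inputs []TextEmbeddingsInput) map[groupKey][]int {
    groups := map[groupKey][]int{}
    for idx, in := range inputs {
        k := groupKey{model: in.Model, dimensions: in.Dimensions}
        groups[k] = append(groups[k], idx)
    }
    return groups
}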

func (e *execution) Execute(ctx context.Context, jobs []*base.Job) error {

client := newClient(e.Setup, e.GetLogger())
client.SetRetryCount(retryCount)
client.SetRetryWaitTime(1 * time.Second)

// TODO: we can encapsulate this code into a `ConcurrentExecutor`.
// The `ConcurrentExecutor` will use goroutines to execute jobs in parallel.
batchSize := 4
for _, batch := range chunk(jobs, batchSize) {
var wg sync.WaitGroup
wg.Add(len(batch))
for _, job := range batch {
go func() {
defer wg.Done()
e.worker(ctx, client, job)
}()
}
wg.Wait()
switch e.Task {
case TextEmbeddingsTask:
// OpenAI embedding API supports batch inference, so we'll leverage it
// directly for optimal performance.
e.executeEmbedding(ctx, client, jobs)

default:
// TODO: we can encapsulate this code into a `ConcurrentExecutor`.
// The `ConcurrentExecutor` will use goroutines to execute jobs in parallel.
batchSize := 4
for _, batch := range chunk(jobs, batchSize) {
var wg sync.WaitGroup
wg.Add(len(batch))
for _, job := range batch {
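// Capturing the loop variable here is safe on Go 1.22+, where range
// variables are per-iteration; older toolchains would need job passed
// into the closure as an argument.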
go func() {
defer wg.Done()
e.worker(ctx, client, job)
}()
}
wg.Wait()
}
}

return nil
}
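The TODO retained in the default branch suggests extracting the chunked goroutine fan-out into a ConcurrentExecutor. A minimal sketch of one possible shape; the type, its field, and the run signature are invented for illustration, not taken from the repository:

// concurrentExecutor runs a per-job function over jobs, batchSize
// goroutines at a time, mirroring the loop in Execute's default branch.
type concurrentExecutor struct {
    batchSize int
}

func (c concurrentExecutor) run(ctx context.Context, jobs []*base.Job, fn func(context.Context, *base.Job)) {
    for _, batch := range chunk(jobs, c.batchSize) {
        var wg sync.WaitGroup
        wg.Add(len(batch))
        for _, job := range batch {
            // job is passed as an argument so the sketch stays correct
            // even on toolchains older than Go 1.22.
            go func(job *base.Job) {
                defer wg.Done()
                fn(ctx, job)
            }(job)
        }
        wg.Wait()
    }
}

With that in place, the default branch would reduce to a single call such as concurrentExecutor{batchSize: 4}.run(ctx, jobs, func(ctx context.Context, job *base.Job) { e.worker(ctx, client, job) }).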
