Skip to content

Commit

Permalink
llm: do not silently fail for supplied, but invalid formats (ollama#8130
Browse files Browse the repository at this point in the history
)

Changes in ollama#8002 introduced fixes for bugs with mangling JSON Schemas.
It also fixed a bug where the server would silently fail when clients
requested invalid formats. It also, unfortunately, introduced a bug
where the server would reject requests with an empty format, which
should be allowed.

The change in ollama#8127 updated the code to allow the empty format, but also
reintroduced the regression where the server would silently fail when
the format was set, but invalid.

This commit fixes both regressions. The server does not reject the empty
format, but it does reject invalid formats. It also adds tests to help
us catch regressions in the future.

Also, the updated code provides a more detailed error message when a
client sends a non-empty, but invalid format, echoing the invalid format
in the response.

This commit also takes the opportunity to remove superfluous linter
checks.
  • Loading branch information
bmizerany authored Dec 17, 2024
1 parent 0f06a6d commit 87f0a49
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 29 deletions.
4 changes: 0 additions & 4 deletions .golangci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ linters:
- containedctx
- contextcheck
- errcheck
- exportloopref
- gci
- gocheckcompilerdirectives
- gofmt
- gofumpt
Expand All @@ -30,8 +28,6 @@ linters:
- wastedassign
- whitespace
linters-settings:
gci:
sections: [standard, default, localmodule]
staticcheck:
checks:
- all
Expand Down
58 changes: 33 additions & 25 deletions llm/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -674,21 +674,6 @@ type CompletionResponse struct {
}

func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
if err := s.sem.Acquire(ctx, 1); err != nil {
if errors.Is(err, context.Canceled) {
slog.Info("aborting completion request due to client closing the connection")
} else {
slog.Error("Failed to acquire semaphore", "error", err)
}
return err
}
defer s.sem.Release(1)

// put an upper limit on num_predict to avoid the model running on forever
if req.Options.NumPredict < 0 || req.Options.NumPredict > 10*s.options.NumCtx {
req.Options.NumPredict = 10 * s.options.NumCtx
}

request := map[string]any{
"prompt": req.Prompt,
"stream": true,
Expand All @@ -714,6 +699,39 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
"cache_prompt": true,
}

if len(req.Format) > 0 {
switch {
case bytes.Equal(req.Format, []byte(`""`)):
// fallthrough
case bytes.Equal(req.Format, []byte(`"json"`)):
request["grammar"] = grammarJSON
case bytes.HasPrefix(req.Format, []byte("{")):
// User provided a JSON schema
g := llama.SchemaToGrammar(req.Format)
if g == nil {
return fmt.Errorf("invalid JSON schema in format")
}
request["grammar"] = string(g)
default:
return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema", req.Format)
}
}

if err := s.sem.Acquire(ctx, 1); err != nil {
if errors.Is(err, context.Canceled) {
slog.Info("aborting completion request due to client closing the connection")
} else {
slog.Error("Failed to acquire semaphore", "error", err)
}
return err
}
defer s.sem.Release(1)

// put an upper limit on num_predict to avoid the model running on forever
if req.Options.NumPredict < 0 || req.Options.NumPredict > 10*s.options.NumCtx {
req.Options.NumPredict = 10 * s.options.NumCtx
}

// Make sure the server is ready
status, err := s.getServerStatusRetry(ctx)
if err != nil {
Expand All @@ -722,16 +740,6 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
return fmt.Errorf("unexpected server status: %s", status.ToString())
}

if bytes.Equal(req.Format, []byte(`"json"`)) {
request["grammar"] = grammarJSON
} else if bytes.HasPrefix(req.Format, []byte("{")) {
g := llama.SchemaToGrammar(req.Format)
if g == nil {
return fmt.Errorf("invalid JSON schema in format")
}
request["grammar"] = string(g)
}

// Handling JSON marshaling with special characters unescaped.
buffer := &bytes.Buffer{}
enc := json.NewEncoder(buffer)
Expand Down
63 changes: 63 additions & 0 deletions llm/server_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package llm

import (
"context"
"errors"
"fmt"
"strings"
"testing"

"github.com/ollama/ollama/api"
"golang.org/x/sync/semaphore"
)

// TestLLMServerCompletionFormat verifies that Completion rejects a
// non-empty, invalid "format" value with a descriptive error, while
// allowing "json", a JSON Schema object, and the empty formats through
// to the (already canceled) semaphore acquisition.
//
// NOTE(review): this test predates a refactor that would make the
// Completion method directly testable; it deliberately relies on
// context cancellation to halt processing once the format check passes.
func TestLLMServerCompletionFormat(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	// A non-nil semaphore is required so Completion does not panic
	// before reaching the format-validation code under test.
	s := &llmServer{
		sem: semaphore.NewWeighted(1),
	}

	complete := func(format []byte) error {
		t.Helper()
		return s.Completion(ctx, CompletionRequest{
			Options: new(api.Options),
			Format:  format,
		}, nil)
	}

	// Invalid formats must be rejected with an error echoing the input.
	for _, bad := range []string{"X", `"X"`} {
		err := complete([]byte(bad))
		want := fmt.Sprintf("invalid format: %q; expected \"json\" or a valid JSON Schema", bad)
		if err == nil || !strings.Contains(err.Error(), want) {
			t.Fatalf("err = %v; want %q", err, want)
		}
	}

	// Cancel the context so any request that survives the format check
	// stops at the semaphore instead of proceeding further.
	cancel()

	for _, good := range []string{`"json"`, `{"type":"object"}`, ``, `""`} {
		if err := complete([]byte(good)); !errors.Is(err, context.Canceled) {
			t.Fatalf("Completion: err = %v; expected context.Canceled", err)
		}
	}

	// A nil format (field omitted entirely) must also be accepted.
	if err := complete(nil); !errors.Is(err, context.Canceled) {
		t.Fatalf("Completion: err = %v; expected context.Canceled", err)
	}
}

0 comments on commit 87f0a49

Please sign in to comment.