Skip to content

Commit

Permalink
feat: add mocks for testing the llm function (#998)
Browse files Browse the repository at this point in the history
1. Add `mock.NewArgumentsContext()` for testing llm function.
2. Add `WriteRecord.LLMResult` for retrieving llm function written data.
  • Loading branch information
woorui authored Jan 29, 2025
1 parent a69b2dd commit 8888536
Show file tree
Hide file tree
Showing 4 changed files with 181 additions and 34 deletions.
16 changes: 6 additions & 10 deletions cli/template/go/init_llm_test.tmpl
Original file line number Diff line number Diff line change
@@ -1,36 +1,32 @@
package main

import (
"fmt"
"reflect"
"testing"

"github.com/yomorun/yomo/ai"
"github.com/yomorun/yomo/serverless/mock"
)

func TestHandler(t *testing.T) {
tests := []struct {
name string
ctx *mock.MockContext
// want is the expected data and tag that be written by ctx.Write
want []mock.WriteRecord
// want is the expected result written by ctx.WriteLLMResult()
want string
}{
{
name: "get weather",
ctx: mock.NewMockContext([]byte(`{"arguments":"{\"city\":\"New York\",\"latitude\":40.7128,\"longitude\":-74.0060}"}`), 0x33),
want: []mock.WriteRecord{
{Data: []byte(`{"result":"The current weather in New York (40.712800,-74.006000) is sunny","arguments":"{\"city\":\"New York\",\"latitude\":40.7128,\"longitude\":-74.0060}","is_ok":true}`), Tag: ai.ReducerTag},
},
ctx: mock.NewArgumentsContext(`{"city":"New York","latitude":40.7128,"longitude":-74.0060}`, 0x33),
want: "The current weather in New York (40.712800,-74.006000) is sunny",
},
// TODO: add more test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Handler(tt.ctx)
got := tt.ctx.RecordsWritten()

fmt.Println(string(got[0].Data))
records := tt.ctx.RecordsWritten()
got := records[0].LLMResult

if !reflect.DeepEqual(got, tt.want) {
t.Errorf("TestHandler got: %v, want: %v", got, tt.want)
Expand Down
47 changes: 28 additions & 19 deletions cli/template/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,40 @@ var (

// get template content
func GetContent(command string, sfnType string, lang string, isTest bool) ([]byte, error) {
name, err := getTemplateFileName(command, sfnType, lang, isTest)
if err != nil {
return nil, err
}
f, err := fs.Open(name)
if err != nil {
if os.IsNotExist(err) {
if isTest {
return nil, ErrUnsupportedTest
}
return nil, err

Check warning on line 38 in cli/template/template.go

View check run for this annotation

Codecov / codecov/patch

cli/template/template.go#L27-L38

Added lines #L27 - L38 were not covered by tests
}
return nil, err

Check warning on line 40 in cli/template/template.go

View check run for this annotation

Codecov / codecov/patch

cli/template/template.go#L40

Added line #L40 was not covered by tests
}
defer f.Close()
_, err = f.Stat()
if err != nil {
return nil, err
}

Check warning on line 46 in cli/template/template.go

View check run for this annotation

Codecov / codecov/patch

cli/template/template.go#L42-L46

Added lines #L42 - L46 were not covered by tests

return fs.ReadFile(name)

Check warning on line 48 in cli/template/template.go

View check run for this annotation

Codecov / codecov/patch

cli/template/template.go#L48

Added line #L48 was not covered by tests
}

func getTemplateFileName(command string, sfnType string, lang string, isTest bool) (string, error) {
if command == "" {
command = "init"
}
sfnType, err := validateSfnType(sfnType)
if err != nil {
return nil, err
return "", err
}
lang, err = validateLang(lang)
if err != nil {
return nil, err
return "", err
}
sb := new(strings.Builder)
sb.WriteString(lang)
Expand All @@ -47,25 +71,10 @@ func GetContent(command string, sfnType string, lang string, isTest bool) ([]byt
}
sb.WriteString(".tmpl")

// valdiate the path exists
// validate the path exists
name := sb.String()
f, err := fs.Open(name)
if err != nil {
if os.IsNotExist(err) {
if isTest {
return nil, ErrUnsupportedTest
}
return nil, ErrUnsupportedFeature
}
return nil, err
}
defer f.Close()
_, err = f.Stat()
if err != nil {
return nil, err
}

return fs.ReadFile(name)
return name, nil
}

func validateSfnType(sfnType string) (string, error) {
Expand Down
121 changes: 121 additions & 0 deletions cli/template/template_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package template

import (
"testing"
)

func TestGetTemplateFileName(t *testing.T) {
type args struct {
command string
sfnType string
lang string
isTest bool
}
tests := []struct {
name string
args args
want string
wantErr bool
}{
{
name: "init_llm_go",
args: args{
command: "init",
sfnType: "llm",
lang: "go",
isTest: false,
},
want: "go/init_llm.tmpl",
wantErr: false,
},
{
name: "init_normal_node_test",
args: args{
command: "init",
sfnType: "normal",
lang: "node",
isTest: true,
},
want: "node/init_normal_test.tmpl",
wantErr: false,
},
{
name: "default_command_llm_go",
args: args{
command: "",
sfnType: "llm",
lang: "go",
isTest: false,
},
want: "go/init_llm.tmpl",
wantErr: false,
},
{
name: "unsupported_sfnType",
args: args{
command: "init",
sfnType: "unsupported",
lang: "go",
isTest: false,
},
want: "",
wantErr: true,
},
{
name: "unsupported_lang",
args: args{
command: "init",
sfnType: "llm",
lang: "unsupported",
isTest: false,
},
want: "",
wantErr: true,
},
{
name: "default_sfnType",
args: args{
command: "init",
sfnType: "",
lang: "go",
isTest: false,
},
want: "go/init_llm.tmpl",
wantErr: false,
},
{
name: "default_lang",
args: args{
command: "init",
sfnType: "llm",
lang: "",
isTest: false,
},
want: "go/init_llm.tmpl",
wantErr: false,
},
{
name: "default_sfnType_and_lang",
args: args{
command: "init",
sfnType: "",
lang: "",
isTest: false,
},
want: "go/init_llm.tmpl",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := getTemplateFileName(tt.args.command, tt.args.sfnType, tt.args.lang, tt.args.isTest)
if (err != nil) != tt.wantErr {
t.Errorf("getTemplateFileName() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("getTemplateFileName() = %v, want %v", got, tt.want)
}
})
}
}
31 changes: 26 additions & 5 deletions serverless/mock/mock_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,18 @@ import (
"sync"

"github.com/yomorun/yomo/ai"
"github.com/yomorun/yomo/pkg/id"
"github.com/yomorun/yomo/serverless"
)

var _ serverless.Context = (*MockContext)(nil)

// WriteRecord composes the data, tag and target.
type WriteRecord struct {
Data []byte
Tag uint32
Target string
Data []byte
Tag uint32
Target string
LLMResult string
}

// MockContext mock context.
Expand All @@ -38,6 +40,23 @@ func NewMockContext(data []byte, tag uint32) *MockContext {
}
}

// NewArgumentsContext creates a Context with the provided arguments and tag.
// This function is used for testing the LLM function.
func NewArgumentsContext(arguments string, tag uint32) *MockContext {
fnCall := &ai.FunctionCall{
Arguments: arguments,
ReqID: id.New(16),
ToolCallID: "chatcmpl-" + id.New(29),
}
data, _ := fnCall.Bytes()

return &MockContext{
data: data,
tag: tag,
fnCall: fnCall,
}

Check warning on line 57 in serverless/mock/mock_context.go

View check run for this annotation

Codecov / codecov/patch

serverless/mock/mock_context.go#L45-L57

Added lines #L45 - L57 were not covered by tests
}

// Data incoming data.
func (c *MockContext) Data() []byte {
return c.data
Expand Down Expand Up @@ -123,9 +142,11 @@ func (c *MockContext) WriteLLMResult(result string) error {
}

c.wrSlice = append(c.wrSlice, WriteRecord{
Data: buf,
Tag: ai.ReducerTag,
Data: buf,
Tag: ai.ReducerTag,
LLMResult: result,
})

return nil
}

Expand Down

0 comments on commit 8888536

Please sign in to comment.