diff --git a/.github/workflows/docker.yaml b/.github/workflows/docker.yaml new file mode 100644 index 0000000..412a450 --- /dev/null +++ b/.github/workflows/docker.yaml @@ -0,0 +1,80 @@ +name: Create Docker Image +on: + release: + types: + - created + + workflow_dispatch: + +jobs: + build: + name: Build + strategy: + matrix: + arch: [ amd64, arm64 ] + runs-on: + - ${{ matrix.arch == 'amd64' && 'ubuntu-latest' || matrix.arch }} + env: + OS: linux + ARCH: ${{ matrix.arch }} + DOCKER_REPO: ghcr.io/${{ github.repository }} + DOCKER_SOURCE: https://github.com/${{ github.repository }} + outputs: + tag: ${{ steps.build.outputs.tag }} + permissions: + contents: read + packages: write + steps: + - name: Install build tools + run: | + sudo apt -y update + sudo apt -y install build-essential git + git config --global advice.detachedHead false + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Login + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.repository_owner }} + password: ${{ secrets.GITHUB_TOKEN }} + - name: Build and Push + id: build + run: | + make docker && make docker-push && make docker-version >> "$GITHUB_OUTPUT" + manifest: + name: Manifest + needs: build + strategy: + matrix: + tag: + - ${{ needs.build.outputs.tag }} + - "latest" + runs-on: ubuntu-latest + permissions: + packages: write + steps: + - name: Login + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.repository_owner }} + password: ${{ secrets.GITHUB_TOKEN }} + - name: Create + run: | + docker manifest create ghcr.io/${{ github.repository }}:${{ matrix.tag }} \ + --amend ghcr.io/${{ github.repository }}-linux-amd64:${{ needs.build.outputs.tag }} \ + --amend ghcr.io/${{ github.repository }}-linux-arm64:${{ needs.build.outputs.tag }} + - name: Annotate + run: | + docker manifest annotate --arch amd64 --os linux \ + ghcr.io/${{ github.repository }}:${{ matrix.tag }} \ + ghcr.io/${{ github.repository }}-linux-amd64:${{ needs.build.outputs.tag }} + docker manifest annotate --arch arm64 --os linux \ + ghcr.io/${{ github.repository }}:${{ matrix.tag }} \ + ghcr.io/${{ github.repository }}-linux-arm64:${{ needs.build.outputs.tag }} + - name: Push + run: | + docker manifest push ghcr.io/${{ github.repository }}:${{ matrix.tag }} diff --git a/.gitignore b/.gitignore index 6f72f89..08ef92e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,25 +1,12 @@ -# If you prefer the allow list template instead of the deny list, see community template: -# https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore -# -# Binaries for programs and plugins *.exe *.exe~ *.dll *.so *.dylib - -# Test binary, built with `go test -c` *.test - -# Output of the go coverage tool, specifically when used with LiteIDE *.out - -# Dependency directories (remove the comment below to include it) -# vendor/ - -# Go workspace file go.work go.work.sum - -# env file -.env +vendor/ +build/ +.DS_Store diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..8922c74 --- /dev/null +++ b/Makefile @@ -0,0 +1,118 @@ +# Executables +GO ?= $(shell which go 2>/dev/null) +DOCKER ?= $(shell which docker 2>/dev/null) + +# Locations +BUILD_DIR ?= build +CMD_DIR := $(wildcard cmd/*) + +# VERBOSE=1 +ifneq ($(VERBOSE),) + VERBOSE_FLAG = -v +else + VERBOSE_FLAG = +endif + +# Set OS and Architecture +ARCH ?= $(shell arch | tr A-Z a-z | sed 's/x86_64/amd64/' | sed 's/i386/amd64/' | sed 's/armv7l/arm/' | sed 's/aarch64/arm64/') +OS ?= $(shell uname | tr A-Z a-z) +VERSION ?= $(shell git describe --tags --always | sed 's/^v//') + +# Set build flags +BUILD_MODULE = $(shell cat go.mod | head -1 | cut -d ' ' -f 2) +BUILD_LD_FLAGS += -X $(BUILD_MODULE)/pkg/version.GitSource=${BUILD_MODULE} +BUILD_LD_FLAGS += -X $(BUILD_MODULE)/pkg/version.GitTag=$(shell git describe --tags --always) +BUILD_LD_FLAGS += -X $(BUILD_MODULE)/pkg/version.GitBranch=$(shell git name-rev HEAD --name-only --always) +BUILD_LD_FLAGS += -X $(BUILD_MODULE)/pkg/version.GitHash=$(shell git rev-parse HEAD) +BUILD_LD_FLAGS += -X $(BUILD_MODULE)/pkg/version.GoBuildTime=$(shell date -u '+%Y-%m-%dT%H:%M:%SZ') +BUILD_FLAGS = -ldflags "-s -w ${BUILD_LD_FLAGS}" + +# Docker +DOCKER_REPO ?= ghcr.io/mutablelogic/go-llm +DOCKER_SOURCE ?= ${BUILD_MODULE} +DOCKER_TAG = ${DOCKER_REPO}-${OS}-${ARCH}:${VERSION} + +############################################################################### +# ALL + +.PHONY: all +all: clean build + +############################################################################### +# BUILD + +# Build the commands in the cmd directory +.PHONY: build +build: tidy $(CMD_DIR) + +$(CMD_DIR): go-dep mkdir + @echo Build command $(notdir $@) GOOS=${OS} GOARCH=${ARCH} + @GOOS=${OS} GOARCH=${ARCH} ${GO} build ${BUILD_FLAGS} -o ${BUILD_DIR}/$(notdir $@) ./$@ + +# Build the docker image +.PHONY: docker +docker: docker-dep + @echo build docker image ${DOCKER_TAG} OS=${OS} ARCH=${ARCH} SOURCE=${DOCKER_SOURCE} VERSION=${VERSION} + @${DOCKER} build \ + --tag ${DOCKER_TAG} \ + --build-arg ARCH=${ARCH} \ + --build-arg OS=${OS} \ + --build-arg SOURCE=${DOCKER_SOURCE} \ + --build-arg VERSION=${VERSION} \ + -f etc/docker/Dockerfile . + +# Push docker container +.PHONY: docker-push +docker-push: docker-dep + @echo push docker image: ${DOCKER_TAG} + @${DOCKER} push ${DOCKER_TAG} + +# Print out the version +.PHONY: docker-version +docker-version: docker-dep + @echo "tag=${VERSION}" + +############################################################################### +# TEST + +.PHONY: test +test: unit-test coverage-test + +.PHONY: unit-test +unit-test: go-dep + @echo Unit Tests + @${GO} test ${VERBOSE_FLAG} ./pkg/... + +.PHONY: coverage-test +coverage-test: go-dep mkdir + @echo Test Coverage + @${GO} test -coverprofile ${BUILD_DIR}/coverprofile.out ./pkg/... + +############################################################################### +# CLEAN + +.PHONY: tidy +tidy: + @echo Running go mod tidy + @${GO} mod tidy + +.PHONY: mkdir +mkdir: + @install -d ${BUILD_DIR} + +.PHONY: clean +clean: + @echo Clean + @rm -fr $(BUILD_DIR) + @${GO} clean + +############################################################################### +# DEPENDENCIES + +.PHONY: go-dep +go-dep: + @test -f "${GO}" && test -x "${GO}" || (echo "Missing go binary" && exit 1) + +.PHONY: docker-dep +docker-dep: + @test -f "${DOCKER}" && test -x "${DOCKER}" || (echo "Missing docker binary" && exit 1) \ No newline at end of file diff --git a/agent.go b/agent.go new file mode 100644 index 0000000..b7658dd --- /dev/null +++ b/agent.go @@ -0,0 +1,14 @@ +package llm + +import ( + "context" +) + +// An LLM Agent is a client for the LLM service +type Agent interface { + // Return the name of the agent + Name() string + + // Return the models + Models(context.Context) ([]Model, error) +} diff --git a/attachment.go b/attachment.go new file mode 100644 index 0000000..c7733c4 --- /dev/null +++ b/attachment.go @@ -0,0 +1,43 @@ +package llm + +import ( + "io" + "os" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +// Attachment for messages +type Attachment struct { + filename string + data []byte +} + +//////////////////////////////////////////////////////////////////////////////// +// LIFECYCLE + +// ReadAttachment returns an attachment from a reader object. +// It is the responsibility of the caller to close the reader. +func ReadAttachment(r io.Reader) (*Attachment, error) { + var filename string + data, err := io.ReadAll(r) + if err != nil { + return nil, err + } + if f, ok := r.(*os.File); ok { + filename = f.Name() + } + return &Attachment{filename: filename, data: data}, nil +} + +//////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +func (a *Attachment) Filename() string { + return a.filename +} + +func (a *Attachment) Data() []byte { + return a.data +} diff --git a/cmd/agent/chat.go b/cmd/agent/chat.go new file mode 100644 index 0000000..4067d75 --- /dev/null +++ b/cmd/agent/chat.go @@ -0,0 +1,104 @@ +package main + +import ( + "context" + "errors" + "fmt" + "io" + "strings" + + // Packages + llm "github.com/mutablelogic/go-llm" + agent "github.com/mutablelogic/go-llm/pkg/agent" +) + +//////////////////////////////////////////////////////////////////////////////// +// TYPES + +type ChatCmd struct { + Model string `arg:"" help:"Model name"` + NoStream bool `flag:"nostream" help:"Disable streaming"` + System string `flag:"system" help:"Set the system prompt"` +} + +//////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +func (cmd *ChatCmd) Run(globals *Globals) error { + return runagent(globals, func(ctx context.Context, client llm.Agent) error { + // Get the model + a, ok := client.(*agent.Agent) + if !ok { + return fmt.Errorf("No agents found") + } + model, err := a.GetModel(ctx, cmd.Model) + if err != nil { + return err + } + + // Set the options + opts := []llm.Opt{} + if !cmd.NoStream { + opts = append(opts, llm.WithStream(func(cc llm.ContextContent) { + if text := cc.Text(); text != "" { + fmt.Println(text) + } + })) + } + if cmd.System != "" { + opts = append(opts, llm.WithSystemPrompt(cmd.System)) + } + if globals.toolkit != nil { + opts = append(opts, llm.WithToolKit(globals.toolkit)) + } + + // Create a session + session := model.Context(opts...) + + // Continue looping until end of input + for { + input, err := globals.term.ReadLine(model.Name() + "> ") + if errors.Is(err, io.EOF) { + return nil + } else if err != nil { + return err + } + + // Ignore empty input + input = strings.TrimSpace(input) + if input == "" { + continue + } + + // Feed input into the model + if err := session.FromUser(ctx, input); err != nil { + return err + } + + // Repeat call tools until no more calls are made + for { + calls := session.ToolCalls() + if len(calls) == 0 { + break + } + if session.Text() != "" { + globals.term.Println(session.Text()) + } else { + var names []string + for _, call := range calls { + names = append(names, call.Name()) + } + globals.term.Println("Calling ", strings.Join(names, ", ")) + } + if results, err := globals.toolkit.Run(ctx, calls...); err != nil { + return err + } else if err := session.FromTool(ctx, results...); err != nil { + return err + } + } + + // Print the response + globals.term.Println("\n" + session.Text() + "\n") + } + }) +} diff --git a/cmd/agent/main.go b/cmd/agent/main.go new file mode 100644 index 0000000..8d0970c --- /dev/null +++ b/cmd/agent/main.go @@ -0,0 +1,156 @@ +package main + +import ( + "context" + "os" + "os/signal" + "path/filepath" + "syscall" + + // Packages + kong "github.com/alecthomas/kong" + client "github.com/mutablelogic/go-client" + llm "github.com/mutablelogic/go-llm" + agent "github.com/mutablelogic/go-llm/pkg/agent" + "github.com/mutablelogic/go-llm/pkg/newsapi" + "github.com/mutablelogic/go-llm/pkg/tool" +) + +//////////////////////////////////////////////////////////////////////////////// +// TYPES + +type Globals struct { + // Debugging + Debug bool `name:"debug" help:"Enable debug output"` + Verbose bool `name:"verbose" help:"Enable verbose output"` + + // Agents + Ollama `embed:"" help:"Ollama configuration"` + Anthropic `embed:"" help:"Anthropic configuration"` + + // Tools + NewsAPI `embed:"" help:"NewsAPI configuration"` + + // Context + ctx context.Context + agent llm.Agent + toolkit *tool.ToolKit + term *Term +} + +type Ollama struct { + OllamaEndpoint string `env:"OLLAMA_URL" help:"Ollama endpoint"` +} + +type Anthropic struct { + AnthropicKey string `env:"ANTHROPIC_API_KEY" help:"Anthropic API Key"` +} + +type NewsAPI struct { + NewsKey string `env:"NEWSAPI_KEY" help:"News API Key"` +} + +type CLI struct { + Globals + + // Agents, Models and Tools + Agents ListAgentsCmd `cmd:"" help:"Return a list of agents"` + Models ListModelsCmd `cmd:"" help:"Return a list of models"` + Tools ListToolsCmd `cmd:"" help:"Return a list of tools"` + + // Commands + Download DownloadModelCmd `cmd:"" help:"Download a model"` + Chat ChatCmd `cmd:"" help:"Start a chat session"` +} + +//////////////////////////////////////////////////////////////////////////////// +// MAIN + +func main() { + // Create a cli parser + cli := CLI{} + cmd := kong.Parse(&cli, + kong.Name(execName()), + kong.Description("LLM agent command line interface"), + kong.UsageOnError(), + kong.ConfigureHelp(kong.HelpOptions{Compact: true}), + kong.Vars{}, + ) + + // Create a context + ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer cancel() + cli.Globals.ctx = ctx + + // Create a terminal + term, err := NewTerm(os.Stdout) + if err != nil { + cmd.FatalIfErrorf(err) + return + } else { + cli.Globals.term = term + } + + // Client options + clientopts := []client.ClientOpt{} + if cli.Debug || cli.Verbose { + clientopts = append(clientopts, client.OptTrace(os.Stderr, cli.Verbose)) + } + + // Create an agent + opts := []llm.Opt{} + if cli.OllamaEndpoint != "" { + opts = append(opts, agent.WithOllama(cli.OllamaEndpoint, clientopts...)) + } + if cli.AnthropicKey != "" { + opts = append(opts, agent.WithAnthropic(cli.AnthropicKey, clientopts...)) + } + + // Make a toolkit + toolkit := tool.NewToolKit() + cli.Globals.toolkit = toolkit + + // Register NewsAPI + if cli.NewsKey != "" { + if client, err := newsapi.New(cli.NewsKey, clientopts...); err != nil { + cmd.FatalIfErrorf(err) + } else if err := client.RegisterWithToolKit(toolkit); err != nil { + cmd.FatalIfErrorf(err) + } + } + + // Append the toolkit + opts = append(opts, llm.WithToolKit(toolkit)) + + // Create the agent + agent, err := agent.New(opts...) + cmd.FatalIfErrorf(err) + cli.Globals.agent = agent + + // Run the command + if err := cmd.Run(&cli.Globals); err != nil { + cmd.FatalIfErrorf(err) + return + } +} + +//////////////////////////////////////////////////////////////////////////////// +// PRIVATE METHODS + +func execName() string { + // The name of the executable + name, err := os.Executable() + if err != nil { + panic(err) + } else { + return filepath.Base(name) + } +} + +func clientOpts(cli *CLI) []client.ClientOpt { + result := []client.ClientOpt{} + if cli.Debug { + result = append(result, client.OptTrace(os.Stderr, cli.Verbose)) + } + return result +} diff --git a/cmd/agent/models.go b/cmd/agent/models.go new file mode 100644 index 0000000..1bb96ee --- /dev/null +++ b/cmd/agent/models.go @@ -0,0 +1,117 @@ +package main + +import ( + "context" + "encoding/json" + "fmt" + + // Packages + llm "github.com/mutablelogic/go-llm" + agent "github.com/mutablelogic/go-llm/pkg/agent" + ollama "github.com/mutablelogic/go-llm/pkg/ollama" +) + +//////////////////////////////////////////////////////////////////////////////// +// TYPES + +type ListModelsCmd struct { + Agent []string `help:"Only return models from a specific agent"` +} + +type ListAgentsCmd struct{} + +type ListToolsCmd struct{} + +type DownloadModelCmd struct { + Agent string `arg:"" help:"Agent name"` + Model string `arg:"" help:"Model name"` +} + +//////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +func (cmd *ListToolsCmd) Run(globals *Globals) error { + return runagent(globals, func(ctx context.Context, client llm.Agent) error { + tools := globals.toolkit.Tools(client) + fmt.Println(tools) + return nil + }) +} + +func (cmd *ListModelsCmd) Run(globals *Globals) error { + return runagent(globals, func(ctx context.Context, client llm.Agent) error { + agent, ok := client.(*agent.Agent) + if !ok { + return fmt.Errorf("No agents found") + } + models, err := agent.ListModels(ctx, cmd.Agent...) + if err != nil { + return err + } + fmt.Println(models) + return nil + }) +} + +func (*ListAgentsCmd) Run(globals *Globals) error { + return runagent(globals, func(ctx context.Context, client llm.Agent) error { + agent, ok := client.(*agent.Agent) + if !ok { + return fmt.Errorf("No agents found") + } + + var agents []string + for _, agent := range agent.Agents() { + agents = append(agents, agent.Name()) + } + + data, err := json.MarshalIndent(agents, "", " ") + if err != nil { + return err + } + fmt.Println(string(data)) + + return nil + }) +} + +func (cmd *DownloadModelCmd) Run(globals *Globals) error { + return runagent(globals, func(ctx context.Context, client llm.Agent) error { + agent := getagent(client, cmd.Agent) + if agent == nil { + return fmt.Errorf("No agents found with name %q", cmd.Agent) + } + // Download the model + switch agent.Name() { + case "ollama": + model, err := agent.(*ollama.Client).PullModel(ctx, cmd.Model) + if err != nil { + return err + } + fmt.Println(model) + default: + return fmt.Errorf("Agent %q does not support model download", agent.Name()) + } + return nil + }) +} + +//////////////////////////////////////////////////////////////////////////////// +// PRIVATE METHODS + +func runagent(globals *Globals, fn func(ctx context.Context, agent llm.Agent) error) error { + return fn(globals.ctx, globals.agent) +} + +func getagent(client llm.Agent, name string) llm.Agent { + agent, ok := client.(*agent.Agent) + if !ok { + return nil + } + for _, agent := range agent.Agents() { + if agent.Name() == name { + return agent + } + } + return nil +} diff --git a/cmd/agent/term.go b/cmd/agent/term.go new file mode 100644 index 0000000..346a0e2 --- /dev/null +++ b/cmd/agent/term.go @@ -0,0 +1,79 @@ +package main + +import ( + "fmt" + "io" + "os" + + // Packages + format "github.com/MichaelMure/go-term-text" + color "github.com/fatih/color" + term "golang.org/x/term" +) + +type Term struct { + r io.Reader + fd int + *term.Terminal +} + +func NewTerm(r io.Reader) (*Term, error) { + t := new(Term) + t.r = r + + // Set file descriptor + if osf, ok := r.(*os.File); ok { + t.fd = int(osf.Fd()) + if term.IsTerminal(t.fd) { + t.Terminal = term.NewTerminal(osf, "") + } + } + + // Return success + return t, nil +} + +// Returns the width and height of the terminal, or (0,0) +func (t *Term) Size() (int, int) { + if t.Terminal != nil { + if w, h, err := term.GetSize(t.fd); err == nil { + return w, h + } + } + // Unable to get the size + return 0, 0 +} + +func (t *Term) Println(v ...any) { + text := fmt.Sprint(v...) + w, _ := t.Size() + if w > 0 { + text, _ = format.Wrap(text, w) + } + fmt.Fprintln(os.Stdout, text) +} + +func (t *Term) ReadLine(prompt string) (string, error) { + // Set terminal raw mode + if t.Terminal != nil { + state, err := term.MakeRaw(t.fd) + if err != nil { + return "", err + } + defer term.Restore(t.fd, state) + } + + // Set the prompt with color + if t.Terminal != nil { + prompt = color.New(color.Bold).Sprint(prompt) + t.Terminal.SetPrompt(prompt) + } + + // Read the line + if t.Terminal != nil { + return t.Terminal.ReadLine() + } else { + // Don't support non-terminal input yet + return "", io.EOF + } +} diff --git a/context.go b/context.go new file mode 100644 index 0000000..cc8e83c --- /dev/null +++ b/context.go @@ -0,0 +1,31 @@ +package llm + +import "context" + +////////////////////////////////////////////////////////////////// +// TYPES + +// ContextContent is the content of the last context message +type ContextContent interface { + // Return the current session role, which can be system, assistant, user, tool, tool_result, ... + Role() string + + // Return the current session text, or empty string if no text was returned + Text() string + + // Return the current session tool calls, or empty if no tool calls were made + ToolCalls() []ToolCall +} + +// Context is fed to the agent to generate a response +type Context interface { + ContextContent + + // Generate a response from a user prompt (with attachments and + // other options) + FromUser(context.Context, string, ...Opt) error + + // Generate a response from a tool, passing the results + // from the tool call + FromTool(context.Context, ...ToolResult) error +} diff --git a/error.go b/error.go new file mode 100644 index 0000000..4bb0c9d --- /dev/null +++ b/error.go @@ -0,0 +1,49 @@ +package llm + +import ( + "fmt" +) + +//////////////////////////////////////////////////////////////////////////////// +// GLOBALS + +const ( + ErrSuccess Err = iota + ErrNotFound + ErrBadParameter + ErrNotImplemented + ErrConflict +) + +//////////////////////////////////////////////////////////////////////////////// +// TYPES + +// Errors +type Err int + +//////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +func (e Err) Error() string { + switch e { + case ErrSuccess: + return "success" + case ErrNotFound: + return "not found" + case ErrBadParameter: + return "bad parameter" + case ErrNotImplemented: + return "not implemented" + case ErrConflict: + return "conflict" + } + return fmt.Sprintf("error code %d", int(e)) +} + +func (e Err) With(args ...interface{}) error { + return fmt.Errorf("%w: %s", e, fmt.Sprint(args...)) +} + +func (e Err) Withf(format string, args ...interface{}) error { + return fmt.Errorf("%w: %s", e, fmt.Sprintf(format, args...)) +} diff --git a/etc/docker/Dockerfile b/etc/docker/Dockerfile new file mode 100644 index 0000000..b612cda --- /dev/null +++ b/etc/docker/Dockerfile @@ -0,0 +1,28 @@ +ARG OS +ARG ARCH + +# Run makefile to build all the commands +FROM --platform=${OS}/${ARCH} golang:latest AS builder +ARG OS +ARG ARCH +WORKDIR /usr/src/app +COPY . . + +# Build the server +RUN \ + apt update -y && apt upgrade -y && \ + OS=${OS} ARCH=${ARCH} make build + +# Copy binaries to /usr/local/bin +FROM --platform=${OS}/${ARCH} debian:bookworm-slim +ARG OS +ARG ARCH +ARG SOURCE +COPY --from=builder /usr/src/app/build/* /usr/local/bin/ +RUN apt update -y && apt install -y ca-certificates + +# Labels +LABEL org.opencontainers.image.source=https://${SOURCE} + +# Entrypoint when running the server +ENTRYPOINT [ "/usr/local/bin/agent" ] diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..199ec9e --- /dev/null +++ b/go.mod @@ -0,0 +1,24 @@ +module github.com/mutablelogic/go-llm + +go 1.23.5 + +require ( + github.com/MichaelMure/go-term-text v0.3.1 + github.com/alecthomas/kong v1.7.0 + github.com/djthorpe/go-errors v1.0.3 + github.com/fatih/color v1.9.0 + github.com/mutablelogic/go-client v1.0.10 + github.com/stretchr/testify v1.10.0 + golang.org/x/term v0.28.0 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/mattn/go-colorable v0.1.4 // indirect + github.com/mattn/go-isatty v0.0.11 // indirect + github.com/mattn/go-runewidth v0.0.15 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/rivo/uniseg v0.4.7 // indirect + golang.org/x/sys v0.29.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..e9cc101 --- /dev/null +++ b/go.sum @@ -0,0 +1,48 @@ +github.com/MichaelMure/go-term-text v0.3.1 h1:Kw9kZanyZWiCHOYu9v/8pWEgDQ6UVN9/ix2Vd2zzWf0= +github.com/MichaelMure/go-term-text v0.3.1/go.mod h1:QgVjAEDUnRMlzpS6ky5CGblux7ebeiLnuy9dAaFZu8o= +github.com/alecthomas/assert/v2 v2.11.0 h1:2Q9r3ki8+JYXvGsDyBXwH3LcJ+WK5D0gc5E8vS6K3D0= +github.com/alecthomas/assert/v2 v2.11.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k= +github.com/alecthomas/kong v1.7.0 h1:MnT8+5JxFDCvISeI6vgd/mFbAJwueJ/pqQNzZMsiqZE= +github.com/alecthomas/kong v1.7.0/go.mod h1:p2vqieVMeTAnaC83txKtXe8FLke2X07aruPWXyMPQrU= +github.com/alecthomas/repr v0.4.0 h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc= +github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/djthorpe/go-errors v1.0.3 h1:GZeMPkC1mx2vteXLI/gvxZS0Ee9zxzwD1mcYyKU5jD0= +github.com/djthorpe/go-errors v1.0.3/go.mod h1:HtfrZnMd6HsX75Mtbv9Qcnn0BqOrrFArvCaj3RMnZhY= +github.com/fatih/color v1.9.0 h1:8xPHl4/q1VyqGIPif1F+1V3Y3lSmrq01EabUW3CoW5s= +github.com/fatih/color v1.9.0/go.mod h1:eQcE1qtQxscV5RaZvpXrrb8Drkc3/DdQ+uUYCNjL+zU= +github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= +github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= +github.com/mattn/go-colorable v0.1.4 h1:snbPLB8fVfU9iwbbo30TPtbLRzwWu6aJS6Xh4eaaviA= +github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= +github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.11 h1:FxPOTFNqGkuDUGi3H/qkUbQO4ZiBa2brKq5r0l8TGeM= +github.com/mattn/go-isatty v0.0.11/go.mod h1:PhnuNfih5lzO57/f3n+odYbM4JtupLOxQOAqxQCu2WE= +github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk= +github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= +github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/mutablelogic/go-client v1.0.10 h1:d4t8irXlGNQrQS/+FoUht+1RnjL9lBaf1e2UasN3ifE= +github.com/mutablelogic/go-client v1.0.10/go.mod h1:XbG8KGo2Efi7PGxXs7rhYxYhLeXL6aCSo6sz0mVchiw= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= +github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= +github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= +github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= +golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.28.0 h1:/Ts8HFuMR2E6IP/jlo7QVLZHggjKQbhu/7H0LJFr3Gg= +golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/model.go b/model.go new file mode 100644 index 0000000..221105b --- /dev/null +++ b/model.go @@ -0,0 +1,22 @@ +package llm + +import "context" + +// An Model can be used to generate a response to a user prompt, +// which is passed to an agent. The interaction occurs through +// a session context object. +type Model interface { + // Return the name of the model + Name() string + + // Return am empty session context object for the model, + // setting session options + Context(...Opt) Context + + // Convenience method to create a session context object + // with a user prompt + UserPrompt(string, ...Opt) Context + + // Embedding vector generation + Embedding(context.Context, string, ...Opt) ([]float64, error) +} diff --git a/opt.go b/opt.go new file mode 100644 index 0000000..df91705 --- /dev/null +++ b/opt.go @@ -0,0 +1,232 @@ +package llm + +import ( + "io" + "time" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +// A generic option type, which can set options on an agent or session +type Opt func(*Opts) error + +// set of options +type Opts struct { + agents map[string]Agent // Set of agents + toolkit ToolKit // Toolkit for tools + callback func(ContextContent) // Streaming callback + attachments []*Attachment // Attachments + system string // System prompt + options map[string]any // Additional options +} + +//////////////////////////////////////////////////////////////////////////////// +// LIFECYCLE + +// ApplyOpts returns a structure of options +func ApplyOpts(opts ...Opt) (*Opts, error) { + o := new(Opts) + o.agents = make(map[string]Agent) + o.options = make(map[string]any) + for _, opt := range opts { + if err := opt(o); err != nil { + return nil, err + } + } + return o, nil +} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS - PROPERTIES + +// Return the set of tools +func (o *Opts) ToolKit() ToolKit { + return o.toolkit +} + +// Return the stream function +func (o *Opts) StreamFn() func(ContextContent) { + return o.callback +} + +// Return the system prompt +func (o *Opts) SystemPrompt() string { + return o.system +} + +// Return the array of registered agents +func (o *Opts) Agents() []Agent { + result := make([]Agent, 0, len(o.agents)) + for _, agent := range o.agents { + result = append(result, agent) + } + return result +} + +// Return attachments +func (o *Opts) Attachments() []*Attachment { + return o.attachments +} + +// Set an option value +func (o *Opts) Set(key string, value any) { + o.options[key] = value +} + +// Get an option value +func (o *Opts) Get(key string) any { + if value, exists := o.options[key]; exists { + return value + } + return nil +} + +// Has an option value +func (o *Opts) Has(key string) bool { + _, exists := o.options[key] + return exists +} + +// Get an option value as a string +func (o *Opts) GetString(key string) string { + if value, exists := o.options[key]; exists { + if v, ok := value.(string); ok { + return v + } + } + return "" +} + +// Get an option value as a boolean +func (o *Opts) GetBool(key string) bool { + if value, exists := o.options[key]; exists { + if v, ok := value.(bool); ok { + return v + } + } + return false +} + +// Get an option value as an unsigned integer +func (o *Opts) GetUint64(key string) uint64 { + if value, exists := o.options[key]; exists { + if v, ok := value.(uint64); ok { + return v + } + } + return 0 +} + +// Get an option value as a float64 +func (o *Opts) GetFloat64(key string) float64 { + if value, exists := o.options[key]; exists { + if v, ok := value.(float64); ok { + return v + } + } + return 0 +} + +// Get an option value as a duration +func (o *Opts) GetDuration(key string) time.Duration { + if value, exists := o.options[key]; exists { + if v, ok := value.(time.Duration); ok { + return v + } + } + return 0 +} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS - SET OPTIONS + +// Set toolkit of tools +func WithToolKit(toolkit ToolKit) Opt { + return func(o *Opts) error { + o.toolkit = toolkit + return nil + } +} + +// Set chat streaming function +func WithStream(fn func(ContextContent)) Opt { + return func(o *Opts) error { + o.callback = fn + return nil + } +} + +// Set agent +func WithAgent(agent Agent) Opt { + return func(o *Opts) error { + // Check parameters + if agent == nil { + return ErrBadParameter.With("withAgent") + } + + // Add agent + name := agent.Name() + if _, exists := o.agents[name]; exists { + return ErrConflict.Withf("Agent %q already exists", name) + } else { + o.agents[name] = agent + } + + // Return success + return nil + } + +} + +// Create an attachment +func WithAttachment(r io.Reader) Opt { + return func(o *Opts) error { + if attachment, err := ReadAttachment(r); err != nil { + return err + } else { + o.attachments = append(o.attachments, attachment) + } + return nil + } +} + +// The temperature of the model. Increasing the temperature will make the model answer more creatively. +func WithTemperature(v float64) Opt { + return func(o *Opts) error { + if v < 0.0 || v > 1.0 { + return ErrBadParameter.With("temperature must be between 0.0 and 1.0") + } + o.Set("temperature", v) + return nil + } +} + +// Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while +// a lower value (e.g., 0.5) will generate more focused and conservative text. +func WithTopP(v float64) Opt { + return func(o *Opts) error { + if v < 0.0 || v > 1.0 { + return ErrBadParameter.With("top_p must be between 0.0 and 1.0") + } + o.Set("top_p", v) + return nil + } +} + +// Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more +// diverse answers, while a lower value (e.g. 10) will be more conservative. +func WithTopK(v uint) Opt { + return func(o *Opts) error { + o.Set("top_k", v) + return nil + } +} + +// Set system prompt +func WithSystemPrompt(v string) Opt { + return func(o *Opts) error { + o.system = v + return nil + } +} diff --git a/pkg/agent/agent.go b/pkg/agent/agent.go new file mode 100644 index 0000000..312c327 --- /dev/null +++ b/pkg/agent/agent.go @@ -0,0 +1,173 @@ +package agent + +import ( + "context" + "encoding/json" + "errors" + "slices" + "strings" + + // Packages + llm "github.com/mutablelogic/go-llm" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +type Agent struct { + *llm.Opts +} + +type model struct { + Agent string `json:"agent"` + llm.Model `json:"model"` +} + +var _ llm.Agent = (*Agent)(nil) + +/////////////////////////////////////////////////////////////////////////////// +// LIFECYCLE + +// Return a new agent, composed of a series of agents and tools +func New(opts ...llm.Opt) (*Agent, error) { + agent := new(Agent) + if opts, err := llm.ApplyOpts(opts...); err != nil { + return nil, err + } else { + agent.Opts = opts + } + + // Return success + return agent, nil +} + +/////////////////////////////////////////////////////////////////////////////// +// STRINGIFY + +func (m model) String() string { + data, err := json.MarshalIndent(m, "", " ") + if err != nil { + return err.Error() + } + return string(data) +} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +// Return a list of tool names +func (a *Agent) ToolNames() []string { + if a.ToolKit() == nil { + return nil + } + var result []string + for _, t := range a.ToolKit().Tools(a) { + result = append(result, t.Name()) + } + return result +} + +// Return a list of agent names +func (a *Agent) AgentNames() []string { + var result []string + for _, a := range a.Agents() { + result = append(result, a.Name()) + } + return result +} + +// Return a list of agents +func (a *Agent) AgentsWithName(name ...string) []llm.Agent { + all := a.Agents() + if len(name) == 0 { + return all + } + result := make([]llm.Agent, 0, len(name)) + for _, a := range all { + if slices.Contains(name, a.Name()) { + result = append(result, a) + } + } + return result +} + +// Return a comma-separated list of agent names +func (a *Agent) Name() string { + var keys []string + for _, agent := range a.Agents() { + keys = append(keys, agent.Name()) + } + return strings.Join(keys, ",") +} + +// Return the models from all agents +func (a *Agent) Models(ctx context.Context) ([]llm.Model, error) { + return a.ListModels(ctx) +} + +// Return the models from list of agents +func (a *Agent) ListModels(ctx context.Context, names ...string) ([]llm.Model, error) { + var result error + + // Gather models from agents + agents := a.AgentsWithName(names...) + models := make([]llm.Model, 0, len(agents)*10) + for _, agent := range agents { + agentmodels, err := modelsForAgent(ctx, agent) + if err != nil { + result = errors.Join(result, err) + continue + } else { + models = append(models, agentmodels...) + } + } + + // Return the models with any errors + return models, result +} + +// Return a model by name. If no agents are specified, then all agents are considered. +// If multiple agents are specified, then the first model found is returned. +func (a *Agent) GetModel(ctx context.Context, name string, agentnames ...string) (llm.Model, error) { + var result error + + agents := a.AgentsWithName(agentnames...) + for _, agent := range agents { + models, err := modelsForAgent(ctx, agent, name) + if err != nil { + result = errors.Join(result, err) + continue + } else if len(models) > 0 { + return models[0], result + } + } + + // Return not found + result = errors.Join(result, llm.ErrNotFound.Withf("model %q", name)) + + // Return any errors + return nil, result +} + +/////////////////////////////////////////////////////////////////////////////// +// PRIVATE METHODS + +func modelsForAgent(ctx context.Context, agent llm.Agent, names ...string) ([]llm.Model, error) { + // Gather models + models, err := agent.Models(ctx) + if err != nil { + return nil, err + } + + // Filter models + result := make([]llm.Model, 0, len(models)) + for _, agentmodel := range models { + if len(names) > 0 && !slices.Contains(names, agentmodel.Name()) { + continue + } + result = append(result, &model{Agent: agent.Name(), Model: agentmodel}) + } + + // Return success + return result, nil +} diff --git a/pkg/agent/agent_test.go b/pkg/agent/agent_test.go new file mode 100644 index 0000000..5b78d81 --- /dev/null +++ b/pkg/agent/agent_test.go @@ -0,0 +1,45 @@ +package agent_test + +import ( + "context" + "os" + "testing" + + // Packages + llm "github.com/mutablelogic/go-llm" + agent "github.com/mutablelogic/go-llm/pkg/agent" + assert "github.com/stretchr/testify/assert" +) + +func Test_client_001(t *testing.T) { + assert := assert.New(t) + + opts := []llm.Opt{} + opts = append(opts, GetOllamaEndpoint(t)...) + + // Create a client + client, err := agent.New(opts...) + if assert.NoError(err) { + assert.NotNil(client) + + // Get models + models, err := client.Models(context.TODO()) + if !assert.NoError(err) { + t.FailNow() + } + assert.NotNil(models) + t.Log(models) + } +} + +/////////////////////////////////////////////////////////////////////////////// +// ENVIRONMENT + +func GetOllamaEndpoint(t *testing.T) []llm.Opt { + key := os.Getenv("OLLAMA_URL") + if key == "" { + return []llm.Opt{} + } else { + return []llm.Opt{agent.WithOllama(key)} + } +} diff --git a/pkg/agent/opt.go b/pkg/agent/opt.go new file mode 100644 index 0000000..a316881 --- /dev/null +++ b/pkg/agent/opt.go @@ -0,0 +1,34 @@ +package agent + +import ( + // Packages + client "github.com/mutablelogic/go-client" + llm "github.com/mutablelogic/go-llm" + anthropic "github.com/mutablelogic/go-llm/pkg/anthropic" + ollama "github.com/mutablelogic/go-llm/pkg/ollama" +) + +//////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +func WithOllama(endpoint string, opts ...client.ClientOpt) llm.Opt { + return func(o *llm.Opts) error { + client, err := ollama.New(endpoint, opts...) + if err != nil { + return err + } else { + return llm.WithAgent(client)(o) + } + } +} + +func WithAnthropic(key string, opts ...client.ClientOpt) llm.Opt { + return func(o *llm.Opts) error { + client, err := anthropic.New(key, opts...) + if err != nil { + return err + } else { + return llm.WithAgent(client)(o) + } + } +} diff --git a/pkg/anthropic/client.go b/pkg/anthropic/client.go new file mode 100644 index 0000000..6f4a13b --- /dev/null +++ b/pkg/anthropic/client.go @@ -0,0 +1,53 @@ +/* +anthropic implements an API client for anthropic (https://docs.anthropic.com/en/api/getting-started) +*/ +package anthropic + +import ( + // Packages + client "github.com/mutablelogic/go-client" + llm "github.com/mutablelogic/go-llm" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +type Client struct { + *client.Client +} + +var _ llm.Agent = (*Client)(nil) + +/////////////////////////////////////////////////////////////////////////////// +// GLOBALS + +const ( + endPoint = "https://api.anthropic.com/v1" + defaultVersion = "2023-06-01" + defaultName = "anthropic" +) + +/////////////////////////////////////////////////////////////////////////////// +// LIFECYCLE + +// Create a new client +func New(ApiKey string, opts ...client.ClientOpt) (*Client, error) { + // Create client + opts = append(opts, client.OptEndpoint(endPoint)) + opts = append(opts, client.OptHeader("x-api-key", ApiKey), client.OptHeader("anthropic-version", defaultVersion)) + client, err := client.New(opts...) + if err != nil { + return nil, err + } + + // Return the client + return &Client{client}, nil +} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +// Return the name of the agent +func (*Client) Name() string { + return defaultName +} diff --git a/pkg/anthropic/client_test.go b/pkg/anthropic/client_test.go new file mode 100644 index 0000000..ea7e2ee --- /dev/null +++ b/pkg/anthropic/client_test.go @@ -0,0 +1,32 @@ +package anthropic_test + +import ( + "os" + "testing" + + // Packages + opts "github.com/mutablelogic/go-client" + anthropic "github.com/mutablelogic/go-llm/pkg/anthropic" + assert "github.com/stretchr/testify/assert" +) + +func Test_client_001(t *testing.T) { + assert := assert.New(t) + client, err := anthropic.New(GetApiKey(t), opts.OptTrace(os.Stderr, true)) + if assert.NoError(err) { + assert.NotNil(client) + t.Log(client) + } +} + +/////////////////////////////////////////////////////////////////////////////// +// ENVIRONMENT + +func GetApiKey(t *testing.T) string { + key := os.Getenv("ANTHROPIC_API_KEY") + if key == "" { + t.Skip("ANTHROPIC_API_KEY not set, skipping tests") + t.SkipNow() + } + return key +} diff --git a/pkg/anthropic/message.go b/pkg/anthropic/message.go new file mode 100644 index 0000000..3a8acdd --- /dev/null +++ b/pkg/anthropic/message.go @@ -0,0 +1,193 @@ +package anthropic + +import ( + "encoding/json" + "net/http" + "strings" + + // Packages + llm "github.com/mutablelogic/go-llm" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +// Message with text or object content +type MessageMeta struct { + Role string `json:"role"` + Content []*Content `json:"content,omitempty"` +} + +type Content struct { + Type string `json:"type"` // image, document, text, tool_use + ContentText + ContentAttachment + *ContentTool + ContentToolResult + CacheControl *cachecontrol `json:"cache_control,omitempty"` // ephemeral +} + +type ContentText struct { + Text string `json:"text,omitempty"` // text content +} + +type ContentTool struct { + Id string `json:"id,omitempty"` // tool id + Name string `json:"name,omitempty"` // tool name + Input map[string]any `json:"input"` // tool input + InputJson string `json:"partial_json,omitempty"` // partial json input (for streaming) +} + +type ContentAttachment struct { + Title string `json:"title,omitempty"` // title of the document + Context string `json:"context,omitempty"` // context of the document + Citations *contentcitation `json:"citations,omitempty"` // citations of the document + Source *contentsource `json:"source,omitempty"` // image or document content +} + +type ContentToolResult struct { + Id string `json:"tool_use_id,omitempty"` // tool id + Content any `json:"content,omitempty"` +} + +type contentsource struct { + Type string `json:"type"` // base64 or text + MediaType string `json:"media_type"` // image/jpeg, image/png, image/gif, image/webp, application/pdf, text/plain + Data any `json:"data"` // ...base64 or text encoded data +} + +type cachecontrol struct { + Type string `json:"type"` // ephemeral +} + +type contentcitation struct { + Enabled bool `json:"enabled"` // true +} + +/////////////////////////////////////////////////////////////////////////////// +// LIFECYCLE + +// Return a Content object with text content +func NewTextContent(v string) *Content { + content := new(Content) + content.Type = "text" + content.ContentText.Text = v + return content +} + +// Return a Content object with tool result +func NewToolResultContent(v llm.ToolResult) *Content { + content := new(Content) + content.Type = "tool_result" + content.ContentToolResult.Id = v.Call().Id() + // content.ContentToolResult.Name = v.Call().Name() + + // We only support JSON encoding for the moment + data, err := json.Marshal(v.Value()) + if err != nil { + content.ContentToolResult.Content = err.Error() + } else { + content.ContentToolResult.Content = string(data) + } + + return content +} + +/////////////////////////////////////////////////////////////////////////////// +// STRINGIFY + +func (m MessageMeta) String() string { + data, err := json.MarshalIndent(m, "", " ") + if err != nil { + return err.Error() + } + return string(data) +} + +/////////////////////////////////////////////////////////////////////////////// +// PRIVATE METHODS + +func (m MessageMeta) Text() string { + if len(m.Content) == 0 { + return "" + } + var text []string + for _, content := range m.Content { + if content.Type == "text" { + text = append(text, content.ContentText.Text) + } + } + return strings.Join(text, "\n") +} + +/////////////////////////////////////////////////////////////////////////////// +// PRIVATE METHODS + +var ( + supportedAttachments = map[string]string{ + "image/jpeg": "image", + "image/png": "image", + "image/gif": "image", + "image/webp": "image", + "application/pdf": "document", + "text/plain": "text", + } +) + +// Read content from an io.Reader +func attachmentContent(attachment *llm.Attachment, ephemeral, citations bool) (*Content, error) { + // Detect mimetype + mimetype := http.DetectContentType(attachment.Data()) + if strings.HasPrefix(mimetype, "text/") { + // Switch to text/plain - TODO: charsets? + mimetype = "text/plain" + } + + // Check supported mimetype + typ, exists := supportedAttachments[mimetype] + if !exists { + return nil, llm.ErrBadParameter.Withf("unsupported or undetected mimetype %q", mimetype) + } + + // Create attachment + content := new(Content) + content.Type = typ + if ephemeral { + content.CacheControl = &cachecontrol{Type: "ephemeral"} + } + + // Handle by type + switch typ { + case "text": + content.Type = "document" + content.Title = attachment.Filename() + content.Source = &contentsource{ + Type: "text", + MediaType: mimetype, + Data: string(attachment.Data()), + } + if citations { + content.Citations = &contentcitation{Enabled: true} + } + case "document": + content.Source = &contentsource{ + Type: "base64", + MediaType: mimetype, + Data: attachment.Data(), + } + if citations { + content.Citations = &contentcitation{Enabled: true} + } + case "image": + content.Source = &contentsource{ + Type: "base64", + MediaType: mimetype, + Data: attachment.Data(), + } + default: + return nil, llm.ErrBadParameter.Withf("unsupported attachment type %q", typ) + } + + // Return success + return content, nil +} diff --git a/pkg/anthropic/messages.go b/pkg/anthropic/messages.go new file mode 100644 index 0000000..d8becd3 --- /dev/null +++ b/pkg/anthropic/messages.go @@ -0,0 +1,239 @@ +package anthropic + +import ( + "context" + "encoding/json" + "fmt" + + // Packages + client "github.com/mutablelogic/go-client" + llm "github.com/mutablelogic/go-llm" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +// Messages Response +type Response struct { + Type string `json:"type"` + Model string `json:"model"` + Id string `json:"id"` + MessageMeta + Reason string `json:"stop_reason,omitempty"` + StopSequence *string `json:"stop_sequence,omitempty"` + Metrics `json:"usage,omitempty"` +} + +// Metrics +type Metrics struct { + CacheCreationInputTokens uint `json:"cache_creation_input_tokens,omitempty"` + CacheReadInputTokens uint `json:"cache_read_input_tokens,omitempty"` + InputTokens uint `json:"input_tokens,omitempty"` + OutputTokens uint `json:"output_tokens,omitempty"` +} + +/////////////////////////////////////////////////////////////////////////////// +// STRINGIFY + +func (r Response) String() string { + data, err := json.MarshalIndent(r, "", " ") + if err != nil { + return err.Error() + } + return string(data) +} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +type reqMessages struct { + Model string `json:"model"` + MaxTokens uint `json:"max_tokens,omitempty"` + Metadata *optmetadata `json:"metadata,omitempty"` + StopSequences []string `json:"stop_sequences,omitempty"` + Stream bool `json:"stream,omitempty"` + System string `json:"system,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopK uint64 `json:"top_k,omitempty"` + TopP float64 `json:"top_p,omitempty"` + Messages []*MessageMeta `json:"messages"` + Tools []llm.Tool `json:"tools,omitempty"` +} + +func (anthropic *Client) Messages(ctx context.Context, context llm.Context, opts ...llm.Opt) (*Response, error) { + // Apply options + opt, err := llm.ApplyOpts(opts...) + if err != nil { + return nil, err + } + + // Request + req, err := client.NewJSONRequest(reqMessages{ + Model: context.(*session).model.Name(), + Messages: context.(*session).seq, + Tools: optTools(anthropic, opt), + MaxTokens: optMaxTokens(context.(*session).model, opt), + Metadata: optMetadata(opt), + StopSequences: optStopSequences(opt), + Stream: optStream(opt), + System: optSystemPrompt(opt), + Temperature: optTemperature(opt), + TopK: optTopK(opt), + TopP: optTopP(opt), + }) + if err != nil { + return nil, err + } + + // Stream + var response Response + reqopts := []client.RequestOpt{ + client.OptPath("messages"), + } + if optStream(opt) { + // Append delta to content + appendDelta := func(content []*Content, delta *Content) ([]*Content, error) { + if len(content) == 0 { + return nil, fmt.Errorf("unexpected delta") + } + + // Get the content block we want to append to + last := content[len(content)-1] + + // Append text_delta + switch { + case last.Type == "text" && delta.Type == "text_delta": + last.Text += delta.Text + case last.Type == "tool_use" && delta.Type == "input_json_delta": + last.InputJson += delta.InputJson + default: + return nil, fmt.Errorf("unexpected delta %s for %s", delta.Type, last.Type) + } + + // Return the content + return content, nil + } + reqopts = append(reqopts, client.OptTextStreamCallback(func(evt client.TextStreamEvent) error { + switch evt.Event { + case "message_start": + // Start of a message + var r struct { + Type string `json:"type"` + Response Response `json:"message"` + } + if err := evt.Json(&r); err != nil { + return err + } else { + response = r.Response + } + case "content_block_start": + // Start of a content block, append to response + var r struct { + Type string `json:"type"` + Index uint `json:"index"` + Content Content `json:"content_block"` + } + if err := evt.Json(&r); err != nil { + return err + } else if int(r.Index) != len(response.MessageMeta.Content) { + return fmt.Errorf("%s: unexpected index %d", r.Type, r.Index) + } else { + response.MessageMeta.Content = append(response.MessageMeta.Content, &r.Content) + } + case "content_block_delta": + // Continuation of a content block, append to content + var r struct { + Type string `json:"type"` + Index uint `json:"index"` + Content Content `json:"delta"` + } + if err := evt.Json(&r); err != nil { + return err + } else if int(r.Index) != len(response.MessageMeta.Content)-1 { + return fmt.Errorf("%s: unexpected index %d", r.Type, r.Index) + } else if content, err := appendDelta(response.MessageMeta.Content, &r.Content); err != nil { + return err + } else { + response.MessageMeta.Content = content + } + case "content_block_stop": + // End of a content block + var r struct { + Type string `json:"type"` + Index uint `json:"index"` + } + if err := evt.Json(&r); err != nil { + return err + } else if int(r.Index) != len(response.MessageMeta.Content)-1 { + return fmt.Errorf("%s: unexpected index %d", r.Type, r.Index) + } + // We need to convert the partial_json response into a full json object + content := response.MessageMeta.Content[r.Index] + if content.Type == "tool_use" && content.InputJson != "" { + if err := json.Unmarshal([]byte(content.InputJson), &content.Input); err != nil { + return err + } + } + case "message_delta": + // Message update + var r struct { + Type string `json:"type"` + Delta Response `json:"delta"` + Usage Metrics `json:"usage"` + } + if err := evt.Json(&r); err != nil { + return err + } + + // Update stop reason + response.Reason = r.Delta.Reason + response.StopSequence = r.Delta.StopSequence + + // Update metrics + response.Metrics.InputTokens += r.Usage.InputTokens + response.Metrics.OutputTokens += r.Usage.OutputTokens + response.Metrics.CacheCreationInputTokens += r.Usage.CacheCreationInputTokens + response.Metrics.CacheReadInputTokens += r.Usage.CacheReadInputTokens + case "message_stop": + // NO-OP + return nil + case "ping": + // NO-OP + return nil + default: + // NO-OP + return nil + } + + if fn := opt.StreamFn(); fn != nil { + fn(&response) + } + + // Return success + return nil + })) + } + + // Response + if err := anthropic.DoWithContext(ctx, req, &response, reqopts...); err != nil { + return nil, err + } + + // Return success + return &response, nil +} + +/////////////////////////////////////////////////////////////////////////////// +// INTERFACE - CONTEXT CONTENT + +func (response Response) Role() string { + return response.MessageMeta.Role +} + +func (response Response) Text() string { + return response.MessageMeta.Text() +} + +func (response Response) ToolCalls() []llm.ToolCall { + return nil +} diff --git a/pkg/anthropic/messages_test.go b/pkg/anthropic/messages_test.go new file mode 100644 index 0000000..41e9152 --- /dev/null +++ b/pkg/anthropic/messages_test.go @@ -0,0 +1,180 @@ +package anthropic_test + +import ( + "context" + "encoding/json" + "log" + "os" + "testing" + + // Packages + opts "github.com/mutablelogic/go-client" + llm "github.com/mutablelogic/go-llm" + anthropic "github.com/mutablelogic/go-llm/pkg/anthropic" + tool "github.com/mutablelogic/go-llm/pkg/tool" + assert "github.com/stretchr/testify/assert" +) + +func Test_messages_001(t *testing.T) { + assert := assert.New(t) + client, err := anthropic.New(GetApiKey(t), opts.OptTrace(os.Stderr, true)) + if assert.NoError(err) { + assert.NotNil(client) + t.Log(client) + } + + model, err := client.GetModel(context.TODO(), "claude-3-haiku-20240307") + if assert.NoError(err) { + assert.NotNil(client) + t.Log(client) + } else { + t.FailNow() + } + + f, err := os.Open("testdata/guggenheim.jpg") + if !assert.NoError(err) { + t.FailNow() + } + defer f.Close() + + response, err := client.Messages(context.TODO(), model.UserPrompt("what is this image?", llm.WithAttachment(f))) + if assert.NoError(err) { + t.Log(response) + } +} + +func Test_messages_002(t *testing.T) { + assert := assert.New(t) + client, err := anthropic.New(GetApiKey(t), opts.OptTrace(os.Stderr, true)) + if assert.NoError(err) { + assert.NotNil(client) + t.Log(client) + } + + model, err := client.GetModel(context.TODO(), "claude-3-haiku-20240307") + if assert.NoError(err) { + assert.NotNil(client) + t.Log(client) + } else { + t.FailNow() + } + + f, err := os.Open("testdata/LICENSE") + if !assert.NoError(err) { + t.FailNow() + } + defer f.Close() + + response, err := client.Messages(context.TODO(), model.UserPrompt("summarize this document for me", llm.WithAttachment(f))) + if assert.NoError(err) { + t.Log(response) + } +} + +func Test_messages_003(t *testing.T) { + assert := assert.New(t) + client, err := anthropic.New(GetApiKey(t), opts.OptTrace(os.Stderr, true)) + if assert.NoError(err) { + assert.NotNil(client) + t.Log(client) + } + + model, err := client.GetModel(context.TODO(), "claude-3-haiku-20240307") + if assert.NoError(err) { + assert.NotNil(client) + t.Log(client) + } else { + t.FailNow() + } + + response, err := client.Messages(context.TODO(), model.UserPrompt("why is the sky blue"), llm.WithStream(func(r llm.ContextContent) { + t.Log(r) + })) + if assert.NoError(err) { + t.Log(response) + } +} + +func Test_messages_004(t *testing.T) { + assert := assert.New(t) + client, err := anthropic.New(GetApiKey(t), opts.OptTrace(os.Stderr, true)) + if assert.NoError(err) { + assert.NotNil(client) + t.Log(client) + } + + model, err := client.GetModel(context.TODO(), "claude-3-haiku-20240307") + if assert.NoError(err) { + assert.NotNil(client) + t.Log(client) + } else { + t.FailNow() + } + + toolkit := tool.NewToolKit() + if err := toolkit.Register(new(weather)); !assert.NoError(err) { + t.FailNow() + } + + response, err := client.Messages(context.TODO(), model.UserPrompt("why is the sky blue"), llm.WithToolKit(toolkit)) + if assert.NoError(err) { + t.Log(response) + } +} + +func Test_messages_005(t *testing.T) { + assert := assert.New(t) + client, err := anthropic.New(GetApiKey(t), opts.OptTrace(os.Stderr, true)) + if assert.NoError(err) { + assert.NotNil(client) + t.Log(client) + } + + model, err := client.GetModel(context.TODO(), "claude-3-haiku-20240307") + if assert.NoError(err) { + assert.NotNil(client) + t.Log(client) + } else { + t.FailNow() + } + + toolkit := tool.NewToolKit() + if err := toolkit.Register(new(weather)); !assert.NoError(err) { + t.FailNow() + } + + response, err := client.Messages(context.TODO(), model.UserPrompt("why is the sky blue"), llm.WithStream(func(r llm.ContextContent) { + t.Log(r) + }), llm.WithToolKit(toolkit)) + if assert.NoError(err) { + t.Log(response) + } +} + +//////////////////////////////////////////////////////////////////////////////// +// TOOLS + +type weather struct { + Location string `json:"location" name:"location" help:"The location to get the weather for" required:"true"` +} + +func (*weather) Name() string { + return "weather_in_location" +} + +func (*weather) Description() string { + return "Get the weather in a location" +} + +func (weather *weather) String() string { + data, err := json.MarshalIndent(weather, "", " ") + if err != nil { + return err.Error() + } + return string(data) +} + +func (weather *weather) Run(ctx context.Context) (any, error) { + log.Println("weather_in_location", "=>", weather) + return "very sunny today", nil +} diff --git a/pkg/anthropic/model.go b/pkg/anthropic/model.go new file mode 100644 index 0000000..288cb47 --- /dev/null +++ b/pkg/anthropic/model.go @@ -0,0 +1,96 @@ +package anthropic + +import ( + "context" + "net/url" + "time" + + // Packages + client "github.com/mutablelogic/go-client" + llm "github.com/mutablelogic/go-llm" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +// model is the implementation of the llm.Model interface +type model struct { + client *Client + ModelMeta +} + +var _ llm.Model = (*model)(nil) + +// ModelMeta is the metadata for an anthropic model +type ModelMeta struct { + Name string `json:"id"` + Description string `json:"display_name,omitempty"` + Type string `json:"type,omitempty"` + CreatedAt *time.Time `json:"created_at,omitempty"` +} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +// Agent interface +func (anthropic *Client) Models(ctx context.Context) ([]llm.Model, error) { + return anthropic.ListModels(ctx) +} + +// Get a model by name +func (anthropic *Client) GetModel(ctx context.Context, name string) (llm.Model, error) { + var response ModelMeta + if err := anthropic.DoWithContext(ctx, nil, &response, client.OptPath("models", name)); err != nil { + return nil, err + } + + // Return success + return &model{client: anthropic, ModelMeta: response}, nil +} + +// List models +func (anthropic *Client) ListModels(ctx context.Context) ([]llm.Model, error) { + // Send the request + var response struct { + Body []ModelMeta `json:"data"` + HasMore bool `json:"has_more"` + FirstId string `json:"first_id"` + LastId string `json:"last_id"` + } + + request := url.Values{} + result := make([]llm.Model, 0, 100) + for { + if err := anthropic.DoWithContext(ctx, nil, &response, client.OptPath("models"), client.OptQuery(request)); err != nil { + return nil, err + } + + // Convert to llm.Model + for _, meta := range response.Body { + result = append(result, &model{ + client: anthropic, + ModelMeta: meta, + }) + } + + // If there are no more models, return + if !response.HasMore { + break + } else { + request.Set("after_id", response.LastId) + } + } + + // Return models + return result, nil +} + +// Return the name of a model +func (model *model) Name() string { + return model.ModelMeta.Name +} + +// Embedding vector generation - not supported on Anthropic +func (*model) Embedding(context.Context, string, ...llm.Opt) ([]float64, error) { + return nil, llm.ErrNotImplemented +} diff --git a/pkg/anthropic/opt.go b/pkg/anthropic/opt.go new file mode 100644 index 0000000..3f64a42 --- /dev/null +++ b/pkg/anthropic/opt.go @@ -0,0 +1,120 @@ +package anthropic + +import ( + "strings" + + // Packages + llm "github.com/mutablelogic/go-llm" +) + +//////////////////////////////////////////////////////////////////////////////// +// TYPES + +type optmetadata struct { + User string `json:"user_id,omitempty"` +} + +//////////////////////////////////////////////////////////////////////////////// +// OPTIONS + +func WithMaxTokens(v uint) llm.Opt { + return func(o *llm.Opts) error { + o.Set("max_tokens", v) + return nil + } +} + +func WithUser(v string) llm.Opt { + return func(o *llm.Opts) error { + o.Set("user", v) + return nil + } +} + +func WithStopSequences(v ...string) llm.Opt { + return func(o *llm.Opts) error { + o.Set("stop", v) + return nil + } +} + +func WithEphemeral() llm.Opt { + return func(o *llm.Opts) error { + o.Set("ephemeral", true) + return nil + } +} + +func WithCitations() llm.Opt { + return func(o *llm.Opts) error { + o.Set("citations", true) + return nil + } +} + +/////////////////////////////////////////////////////////////////////////////// +// PRIVATE METHODS + +func optCitations(opt *llm.Opts) bool { + return opt.GetBool("citations") +} + +func optEphemeral(opt *llm.Opts) bool { + return opt.GetBool("ephemeral") +} + +func optTools(agent *Client, opts *llm.Opts) []llm.Tool { + toolkit := opts.ToolKit() + if toolkit == nil { + return nil + } + return toolkit.Tools(agent) +} + +func optMaxTokens(model llm.Model, opt *llm.Opts) uint { + // https://docs.anthropic.com/en/docs/about-claude/models + switch { + case strings.Contains(model.Name(), "claude-3-5-haiku"): + return 8192 + case strings.Contains(model.Name(), "claude-3-5-sonnet"): + return 8192 + default: + return 4096 + } +} + +func optMetadata(opt *llm.Opts) *optmetadata { + if user, ok := opt.Get("user").(string); ok { + return &optmetadata{User: user} + } + return nil +} + +func optStopSequences(opt *llm.Opts) []string { + if opt.Has("stop") { + if stop, ok := opt.Get("stop").([]string); ok { + return stop + } + } + return nil +} + +func optStream(opt *llm.Opts) bool { + return opt.StreamFn() != nil +} + +func optSystemPrompt(opt *llm.Opts) string { + return opt.SystemPrompt() +} + +func optTemperature(opt *llm.Opts) float64 { + return opt.GetFloat64("temperature") +} + +func optTopK(opt *llm.Opts) uint64 { + return opt.GetUint64("top_k") +} + +func optTopP(opt *llm.Opts) float64 { + return opt.GetFloat64("top_p") +} diff --git a/pkg/anthropic/session.go b/pkg/anthropic/session.go new file mode 100644 index 0000000..d545582 --- /dev/null +++ b/pkg/anthropic/session.go @@ -0,0 +1,213 @@ +package anthropic + +import ( + "context" + "encoding/json" + + // Packages + llm "github.com/mutablelogic/go-llm" + tool "github.com/mutablelogic/go-llm/pkg/tool" +) + +////////////////////////////////////////////////////////////////// +// TYPES + +type session struct { + model *model + opts []llm.Opt + seq []*MessageMeta +} + +var _ llm.Context = (*session)(nil) + +/////////////////////////////////////////////////////////////////////////////// +// LIFECYCLE + +// Return an empty session context object for the model, setting session options +func (model *model) Context(opts ...llm.Opt) llm.Context { + return &session{ + model: model, + opts: opts, + } +} + +// Convenience method to create a session context object with a user prompt, which +// panics on error +func (model *model) UserPrompt(prompt string, opts ...llm.Opt) llm.Context { + context := model.Context(opts...) + + meta, err := userPrompt(prompt, opts...) + if err != nil { + panic(err) + } + + // Add to the sequence + context.(*session).seq = append(context.(*session).seq, meta) + + // Return success + return context +} + +/////////////////////////////////////////////////////////////////////////////// +// STRINGIFY + +func (session session) String() string { + var data []byte + var err error + if len(session.seq) == 1 { + data, err = json.MarshalIndent(session.seq[0], "", " ") + } else { + data, err = json.MarshalIndent(session.seq, "", " ") + } + if err != nil { + return err.Error() + } + return string(data) +} + +////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +// Return the role of the last message +func (session *session) Role() string { + if len(session.seq) == 0 { + return "" + } + return session.seq[len(session.seq)-1].Role +} + +// Return the text of the last message +func (session *session) Text() string { + if len(session.seq) == 0 { + return "" + } + meta := session.seq[len(session.seq)-1] + return meta.Text() +} + +// Return the current session tool calls, or empty if no tool calls were made +func (session *session) ToolCalls() []llm.ToolCall { + // Sanity check for tool call + if len(session.seq) == 0 { + return nil + } + meta := session.seq[len(session.seq)-1] + if meta.Role != "assistant" { + return nil + } + + // Gather tool calls + var result []llm.ToolCall + for _, content := range meta.Content { + if content.Type == "tool_use" { + result = append(result, tool.NewCall(content.ContentTool.Id, content.ContentTool.Name, content.ContentTool.Input)) + } + } + return result +} + +// Generate a response from a user prompt (with attachments) and +// other empheral options +func (session *session) FromUser(ctx context.Context, prompt string, opts ...llm.Opt) error { + // Append the user prompt to the sequence + meta, err := userPrompt(prompt, opts...) + if err != nil { + return err + } else { + session.seq = append(session.seq, meta) + } + + // The options come from the session options and the user options + chatopts := make([]llm.Opt, 0, len(session.opts)+len(opts)) + chatopts = append(chatopts, session.opts...) + chatopts = append(chatopts, opts...) + + // Call the 'chat' method + client := session.model.client + r, err := client.Messages(ctx, session, chatopts...) + if err != nil { + return err + } else { + session.seq = append(session.seq, &r.MessageMeta) + } + + // Return success + return nil +} + +// Generate a response from a tool, passing the call identifier or +// function name, and the result +func (session *session) FromTool(ctx context.Context, results ...llm.ToolResult) error { + meta, err := toolResults(results...) + if err != nil { + return err + } else { + session.seq = append(session.seq, meta) + } + + // Call the 'chat' method + client := session.model.client + r, err := client.Messages(ctx, session, session.opts...) + if err != nil { + return err + } else { + session.seq = append(session.seq, &r.MessageMeta) + } + + // Return success + return nil +} + +/////////////////////////////////////////////////////////////////////////////// +// PRIVATE METHODS + +func userPrompt(prompt string, opts ...llm.Opt) (*MessageMeta, error) { + // Apply attachments + opt, err := llm.ApplyOpts(opts...) + if err != nil { + return nil, err + } + + // Get attachments + attachments := opt.Attachments() + + // Create user message + meta := MessageMeta{ + Role: "user", + Content: make([]*Content, 1, len(attachments)+1), + } + + // Append the text + meta.Content[0] = NewTextContent(prompt) + + // Append any additional data + for _, attachment := range attachments { + content, err := attachmentContent(attachment, optEphemeral(opt), optCitations(opt)) + if err != nil { + return nil, err + } + meta.Content = append(meta.Content, content) + } + + // Return success + return &meta, nil +} + +func toolResults(results ...llm.ToolResult) (*MessageMeta, error) { + // Check for no results + if len(results) == 0 { + return nil, llm.ErrBadParameter.Withf("No tool results") + } + + // Create user message + meta := MessageMeta{ + Role: "user", + Content: make([]*Content, 0, len(results)), + } + for _, result := range results { + meta.Content = append(meta.Content, NewToolResultContent(result)) + } + + // Return success + return &meta, nil +} diff --git a/pkg/anthropic/session_test.go b/pkg/anthropic/session_test.go new file mode 100644 index 0000000..78c01de --- /dev/null +++ b/pkg/anthropic/session_test.go @@ -0,0 +1,91 @@ +package anthropic_test + +import ( + "context" + "os" + "testing" + + // Packages + opts "github.com/mutablelogic/go-client" + llm "github.com/mutablelogic/go-llm" + anthropic "github.com/mutablelogic/go-llm/pkg/anthropic" + tool "github.com/mutablelogic/go-llm/pkg/tool" + assert "github.com/stretchr/testify/assert" +) + +func Test_session_001(t *testing.T) { + client, err := anthropic.New(GetApiKey(t), opts.OptTrace(os.Stderr, true)) + if err != nil { + t.FailNow() + } + + model, err := client.GetModel(context.TODO(), "claude-3-haiku-20240307") + if err != nil { + t.FailNow() + } + + // Session with a single user prompt - streaming + t.Run("stream", func(t *testing.T) { + assert := assert.New(t) + session := model.Context(llm.WithStream(func(stream llm.ContextContent) { + t.Log("SESSION DELTA", stream) + })) + assert.NotNil(session) + + err := session.FromUser(context.TODO(), "Why is the grass green?") + if !assert.NoError(err) { + t.FailNow() + } + assert.Equal("assistant", session.Role()) + assert.NotEmpty(session.Text()) + }) + + // Session with a single user prompt - not streaming + t.Run("nostream", func(t *testing.T) { + assert := assert.New(t) + session := model.Context() + assert.NotNil(session) + + err := session.FromUser(context.TODO(), "Why is the sky blue?") + if !assert.NoError(err) { + t.FailNow() + } + assert.Equal("assistant", session.Role()) + assert.NotEmpty(session.Text()) + }) +} + +func Test_session_002(t *testing.T) { + client, err := anthropic.New(GetApiKey(t), opts.OptTrace(os.Stderr, true)) + if err != nil { + t.FailNow() + } + + model, err := client.GetModel(context.TODO(), "claude-3-haiku-20240307") + if err != nil { + t.FailNow() + } + + // Session with a tool call + t.Run("toolcall", func(t *testing.T) { + assert := assert.New(t) + + toolkit := tool.NewToolKit() + if err := toolkit.Register(new(weather)); !assert.NoError(err) { + t.FailNow() + } + + session := model.Context(llm.WithToolKit(toolkit)) + assert.NotNil(session) + + err = session.FromUser(context.TODO(), "What is today's weather, in Berlin?") + if !assert.NoError(err) { + t.FailNow() + } + + err := toolkit.Run(context.TODO(), session.ToolCalls()...) + if !assert.NoError(err) { + t.FailNow() + } + }) +} diff --git a/pkg/anthropic/testdata/LICENSE b/pkg/anthropic/testdata/LICENSE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/pkg/anthropic/testdata/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/pkg/anthropic/testdata/guggenheim.jpg b/pkg/anthropic/testdata/guggenheim.jpg new file mode 100644 index 0000000..7e16517 Binary files /dev/null and b/pkg/anthropic/testdata/guggenheim.jpg differ diff --git a/pkg/newsapi/README.md b/pkg/newsapi/README.md new file mode 100644 index 0000000..ac53b94 --- /dev/null +++ b/pkg/newsapi/README.md @@ -0,0 +1,8 @@ +# NewsAPI Client + +This package provides a client for the NewsAPI API, which is used to interact with the NewsAPI service. + +References: + +- API https://newsapi.org/docs +- Package https://pkg.go.dev/github.com/mutablelogic/go-llm/pkg/newsapi diff --git a/pkg/newsapi/agent.go b/pkg/newsapi/agent.go new file mode 100644 index 0000000..ee0f658 --- /dev/null +++ b/pkg/newsapi/agent.go @@ -0,0 +1,88 @@ +package newsapi + +import ( + "context" + "fmt" + "slices" + + // Packages + llm "github.com/mutablelogic/go-llm" +) + +/////////////////////////////////////////////////////////////////////////////// +// HEADLINES + +type headlines struct { + *Client `json:"-"` +} + +var _ llm.Tool = (*headlines)(nil) + +func (headlines) Name() string { + return "news_headlines" +} + +func (headlines) Description() string { + return "Return the current global news headlines" +} + +func (headlines *headlines) Run(ctx context.Context) (any, error) { + return headlines.Headlines(OptCategory("general"), OptLimit(10)) +} + +/////////////////////////////////////////////////////////////////////////////// +// SEARCH + +type search struct { + *Client `json:"-"` + Query string `json:"query" help:"A phrase used to search for news headlines." required:"true"` +} + +var _ llm.Tool = (*search)(nil) + +func (search) Name() string { + return "news_search" +} + +func (search) Description() string { + return "Search the news archive with a search query" +} + +func (search *search) Run(ctx context.Context) (any, error) { + if search.Query == "" { + return nil, nil + } + fmt.Printf(" => Search for %q\n", search.Query) + return search.Articles(OptQuery(search.Query), OptLimit(10)) +} + +/////////////////////////////////////////////////////////////////////////////// +// CATEGORY + +type category struct { + *Client `json:"-"` + Category string `json:"category" enum:"business, entertainment, health, science, sports, technology" help:"business, entertainment, health, science, sports, technology" required:"true"` +} + +var _ llm.Tool = (*category)(nil) + +var ( + categories = []string{"business", "entertainment", "health", "science", "sports", "technology"} +) + +func (category) Name() string { + return "news_headlines_category" +} + +func (category) Description() string { + return "Return the news headlines for a specific category" +} + +func (category *category) Run(ctx context.Context) (any, error) { + if !slices.Contains(categories, category.Category) { + fmt.Printf(" => Search for %q\n", category.Category) + return category.Articles(OptQuery(category.Category), OptLimit(10)) + } + fmt.Printf(" => Headlines for %q\n", category.Category) + return category.Headlines(OptCategory(category.Category), OptLimit(10)) +} diff --git a/pkg/newsapi/articles.go b/pkg/newsapi/articles.go new file mode 100644 index 0000000..f16b2f3 --- /dev/null +++ b/pkg/newsapi/articles.go @@ -0,0 +1,82 @@ +package newsapi + +import ( + "time" + + // Packages + "github.com/mutablelogic/go-client" + + // Namespace imports + . "github.com/djthorpe/go-errors" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +type Article struct { + Source Source `json:"source"` + Title string `json:"title"` + Author string `json:"author,omitempty"` + Description string `json:"description,omitempty"` + Url string `json:"url,omitempty"` + ImageUrl string `json:"urlToImage,omitempty"` + PublishedAt time.Time `json:"publishedAt,omitempty"` + Content string `json:"content,omitempty"` +} + +type respArticles struct { + Status string `json:"status"` + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` + TotalResults int `json:"totalResults"` + Articles []Article `json:"articles"` +} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +// Returns headlines +func (c *Client) Headlines(opt ...Opt) ([]Article, error) { + var response respArticles + var query opts + + // Add options + for _, opt := range opt { + if err := opt(&query); err != nil { + return nil, err + } + } + + // Request -> Response + if err := c.Do(nil, &response, client.OptPath("top-headlines"), client.OptQuery(query.Values())); err != nil { + return nil, err + } else if response.Status != "ok" { + return nil, ErrBadParameter.Withf("%s: %s", response.Code, response.Message) + } + + // Return success + return response.Articles, nil +} + +// Returns articles +func (c *Client) Articles(opt ...Opt) ([]Article, error) { + var response respArticles + var query opts + + // Add options + for _, opt := range opt { + if err := opt(&query); err != nil { + return nil, err + } + } + + // Request -> Response + if err := c.Do(nil, &response, client.OptPath("everything"), client.OptQuery(query.Values())); err != nil { + return nil, err + } else if response.Status != "ok" { + return nil, ErrBadParameter.Withf("%s: %s", response.Code, response.Message) + } + + // Return success + return response.Articles, nil +} diff --git a/pkg/newsapi/articles_test.go b/pkg/newsapi/articles_test.go new file mode 100644 index 0000000..8c44bd5 --- /dev/null +++ b/pkg/newsapi/articles_test.go @@ -0,0 +1,38 @@ +package newsapi_test + +import ( + "encoding/json" + "os" + "testing" + + // Packages + opts "github.com/mutablelogic/go-client" + newsapi "github.com/mutablelogic/go-llm/pkg/newsapi" + assert "github.com/stretchr/testify/assert" +) + +func Test_articles_001(t *testing.T) { + assert := assert.New(t) + client, err := newsapi.New(GetApiKey(t), opts.OptTrace(os.Stderr, true)) + assert.NoError(err) + + articles, err := client.Headlines(newsapi.OptQuery("google")) + assert.NoError(err) + assert.NotNil(articles) + + body, _ := json.MarshalIndent(articles, "", " ") + t.Log(string(body)) +} + +func Test_articles_002(t *testing.T) { + assert := assert.New(t) + client, err := newsapi.New(GetApiKey(t), opts.OptTrace(os.Stderr, true)) + assert.NoError(err) + + articles, err := client.Articles(newsapi.OptQuery("google"), newsapi.OptLimit(1)) + assert.NoError(err) + assert.NotNil(articles) + + body, _ := json.MarshalIndent(articles, "", " ") + t.Log(string(body)) +} diff --git a/pkg/newsapi/client.go b/pkg/newsapi/client.go new file mode 100644 index 0000000..6851e35 --- /dev/null +++ b/pkg/newsapi/client.go @@ -0,0 +1,57 @@ +/* +newsapi implements an API client for NewsAPI (https://newsapi.org/docs) +*/ +package newsapi + +import ( + // Packages + "github.com/mutablelogic/go-client" + "github.com/mutablelogic/go-llm/pkg/tool" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +type Client struct { + *client.Client +} + +/////////////////////////////////////////////////////////////////////////////// +// GLOBALS + +const ( + endPoint = "https://newsapi.org/v2" +) + +/////////////////////////////////////////////////////////////////////////////// +// LIFECYCLE + +func New(ApiKey string, opts ...client.ClientOpt) (*Client, error) { + // Create client + client, err := client.New(append(opts, client.OptEndpoint(endPoint), client.OptHeader("X-Api-Key", ApiKey))...) + if err != nil { + return nil, err + } + + // Return the client + return &Client{client}, nil +} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +func (newsapi *Client) RegisterWithToolKit(toolkit *tool.ToolKit) error { + // Register tools + if err := toolkit.Register(&headlines{newsapi}); err != nil { + return err + } + if err := toolkit.Register(&search{newsapi, ""}); err != nil { + return err + } + if err := toolkit.Register(&category{newsapi, ""}); err != nil { + return err + } + + // Return success + return nil +} diff --git a/pkg/newsapi/client_test.go b/pkg/newsapi/client_test.go new file mode 100644 index 0000000..bf641b5 --- /dev/null +++ b/pkg/newsapi/client_test.go @@ -0,0 +1,31 @@ +package newsapi_test + +import ( + "os" + "testing" + + // Packages + opts "github.com/mutablelogic/go-client" + newsapi "github.com/mutablelogic/go-llm/pkg/newsapi" + assert "github.com/stretchr/testify/assert" +) + +func Test_client_001(t *testing.T) { + assert := assert.New(t) + client, err := newsapi.New(GetApiKey(t), opts.OptTrace(os.Stderr, true)) + assert.NoError(err) + assert.NotNil(client) + t.Log(client) +} + +/////////////////////////////////////////////////////////////////////////////// +// ENVIRONMENT + +func GetApiKey(t *testing.T) string { + key := os.Getenv("NEWSAPI_KEY") + if key == "" { + t.Skip("NEWSAPI_KEY not set") + t.SkipNow() + } + return key +} diff --git a/pkg/newsapi/opts.go b/pkg/newsapi/opts.go new file mode 100644 index 0000000..e43c727 --- /dev/null +++ b/pkg/newsapi/opts.go @@ -0,0 +1,114 @@ +package newsapi + +import ( + "fmt" + "net/url" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +type opts struct { + Category string `json:"category,omitempty"` + Language string `json:"language,omitempty"` + Country string `json:"country,omitempty"` + Query string `json:"q,omitempty"` + Limit int `json:"pageSize,omitempty"` + Sort string `json:"sortBy,omitempty"` +} + +// Opt is a function which can be used to set options on a request +type Opt func(*opts) error + +/////////////////////////////////////////////////////////////////////////////// +// METHODS + +func (o *opts) Values() url.Values { + result := url.Values{} + if o.Category != "" { + result.Set("category", o.Category) + } + if o.Language != "" { + result.Set("language", o.Language) + } + if o.Country != "" { + result.Set("country", o.Country) + } + if o.Query != "" { + result.Set("q", o.Query) + } + if o.Limit > 0 { + result.Set("pageSize", fmt.Sprint(o.Limit)) + } + if o.Sort != "" { + result.Set("sortBy", o.Sort) + } + return result +} + +/////////////////////////////////////////////////////////////////////////////// +// OPTIONS + +// Set the category +func OptCategory(v string) Opt { + return func(o *opts) error { + o.Category = v + return nil + } +} + +// Set the language +func OptLanguage(v string) Opt { + return func(o *opts) error { + o.Language = v + return nil + } +} + +// Set the country +func OptCountry(v string) Opt { + return func(o *opts) error { + o.Country = v + return nil + } +} + +// Set the query +func OptQuery(v string) Opt { + return func(o *opts) error { + o.Query = v + return nil + } +} + +// Set the number of results +func OptLimit(v int) Opt { + return func(o *opts) error { + o.Limit = v + return nil + } +} + +// Sort for articles by relevancy +func OptSortByRelevancy() Opt { + return func(o *opts) error { + o.Sort = "relevancy" + return nil + } +} + +// Sort for articles by popularity +func OptSortByPopularity() Opt { + return func(o *opts) error { + o.Sort = "popularity" + return nil + } +} + +// Sort for articles by date +func OptSortByDate() Opt { + return func(o *opts) error { + o.Sort = "publishedAt" + return nil + } +} diff --git a/pkg/newsapi/sources.go b/pkg/newsapi/sources.go new file mode 100644 index 0000000..ef7aa4c --- /dev/null +++ b/pkg/newsapi/sources.go @@ -0,0 +1,63 @@ +package newsapi + +import ( + // Packages + "github.com/mutablelogic/go-client" + + // Namespace imports + . "github.com/djthorpe/go-errors" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +type Source struct { + Id string `json:"id"` + Name string `json:"name"` + Description string `json:"description,omitempty"` + Url string `json:"url,omitempty"` + Category string `json:"category,omitempty"` + Language string `json:"language,omitempty"` + Country string `json:"country,omitempty"` +} + +type respSources struct { + Status string `json:"status"` + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` + Sources []Source `json:"sources"` +} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +// Sources returns all the models. The options which can be passed are: +// +// OptCategory: The category you would like to get sources for. Possible +// options are business, entertainment, general, health, science, sports, +// technology. +// +// OptLanguage: The language you would like to get sources for +// +// OptCountry: The country you would like to get sources for +func (c *Client) Sources(opt ...Opt) ([]Source, error) { + var response respSources + var query opts + + // Add options + for _, opt := range opt { + if err := opt(&query); err != nil { + return nil, err + } + } + + // Request -> Response + if err := c.Do(nil, &response, client.OptPath("top-headlines/sources"), client.OptQuery(query.Values())); err != nil { + return nil, err + } else if response.Status != "ok" { + return nil, ErrBadParameter.Withf("%s: %s", response.Code, response.Message) + } + + // Return success + return response.Sources, nil +} diff --git a/pkg/newsapi/sources_test.go b/pkg/newsapi/sources_test.go new file mode 100644 index 0000000..d61ac1b --- /dev/null +++ b/pkg/newsapi/sources_test.go @@ -0,0 +1,25 @@ +package newsapi_test + +import ( + "encoding/json" + "os" + "testing" + + // Packages + opts "github.com/mutablelogic/go-client" + newsapi "github.com/mutablelogic/go-llm/pkg/newsapi" + assert "github.com/stretchr/testify/assert" +) + +func Test_sources_001(t *testing.T) { + assert := assert.New(t) + client, err := newsapi.New(GetApiKey(t), opts.OptTrace(os.Stderr, true)) + assert.NoError(err) + + sources, err := client.Sources(newsapi.OptLanguage("en")) + assert.NoError(err) + assert.NotNil(sources) + + body, err := json.MarshalIndent(sources, "", " ") + t.Log(string(body)) +} diff --git a/pkg/ollama/chat.go b/pkg/ollama/chat.go new file mode 100644 index 0000000..14bf5f5 --- /dev/null +++ b/pkg/ollama/chat.go @@ -0,0 +1,139 @@ +package ollama + +import ( + "context" + "encoding/json" + "time" + + // Packages + client "github.com/mutablelogic/go-client" + llm "github.com/mutablelogic/go-llm" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +// Chat Response +type Response struct { + Model string `json:"model"` + CreatedAt time.Time `json:"created_at"` + Message MessageMeta `json:"message"` + Done bool `json:"done"` + Reason string `json:"done_reason,omitempty"` + Metrics +} + +// Metrics +type Metrics struct { + TotalDuration time.Duration `json:"total_duration,omitempty"` + LoadDuration time.Duration `json:"load_duration,omitempty"` + PromptEvalCount int `json:"prompt_eval_count,omitempty"` + PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"` + EvalCount int `json:"eval_count,omitempty"` + EvalDuration time.Duration `json:"eval_duration,omitempty"` +} + +/////////////////////////////////////////////////////////////////////////////// +// STRINGIFY + +func (r Response) String() string { + data, err := json.MarshalIndent(r, "", " ") + if err != nil { + return err.Error() + } + return string(data) +} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +type reqChat struct { + Model string `json:"model"` + Messages []*MessageMeta `json:"messages"` + Tools []ToolFunction `json:"tools,omitempty"` + Format string `json:"format,omitempty"` + Options map[string]interface{} `json:"options,omitempty"` + Stream bool `json:"stream"` + KeepAlive *time.Duration `json:"keep_alive,omitempty"` +} + +func (ollama *Client) Chat(ctx context.Context, prompt llm.Context, opts ...llm.Opt) (*Response, error) { + opt, err := llm.ApplyOpts(opts...) + if err != nil { + return nil, err + } + + // Append the system prompt at the beginning + seq := make([]*MessageMeta, 0, len(prompt.(*session).seq)+1) + if system := opt.SystemPrompt(); system != "" { + seq = append(seq, &MessageMeta{ + Role: "system", + Content: opt.SystemPrompt(), + }) + } + seq = append(seq, prompt.(*session).seq...) + + // Request + req, err := client.NewJSONRequest(reqChat{ + Model: prompt.(*session).model.Name(), + Messages: seq, + Tools: optTools(ollama, opt), + Format: optFormat(opt), + Options: optOptions(opt), + Stream: optStream(ollama, opt), + KeepAlive: optKeepAlive(opt), + }) + if err != nil { + return nil, err + } + + // Response + var response, delta Response + if err := ollama.DoWithContext(ctx, req, &delta, client.OptPath("chat"), client.OptJsonStreamCallback(func(v any) error { + if v, ok := v.(*Response); !ok || v == nil { + return llm.ErrConflict.Withf("Invalid stream response: %v", v) + } else { + response.Model = v.Model + response.CreatedAt = v.CreatedAt + response.Message.Role = v.Message.Role + response.Message.Content += v.Message.Content + if v.Done { + response.Done = v.Done + response.Metrics = v.Metrics + response.Reason = v.Reason + } + } + + //Call the chat callback + if optStream(ollama, opt) { + if fn := opt.StreamFn(); fn != nil { + fn(&response) + } + } + return nil + })); err != nil { + return nil, err + } + + // We return the delta or the response + if optStream(ollama, opt) { + return &response, nil + } else { + return &delta, nil + } +} + +/////////////////////////////////////////////////////////////////////////////// +// INTERFACE - CONTEXT CONTENT + +func (response Response) Role() string { + return response.Message.Role +} + +func (response Response) Text() string { + return response.Message.Content +} + +func (response Response) ToolCalls() []llm.ToolCall { + return nil +} diff --git a/pkg/ollama/chat_test.go b/pkg/ollama/chat_test.go new file mode 100644 index 0000000..cbc77c5 --- /dev/null +++ b/pkg/ollama/chat_test.go @@ -0,0 +1,146 @@ +package ollama_test + +import ( + "context" + "encoding/json" + "log" + "os" + "testing" + + // Packages + opts "github.com/mutablelogic/go-client" + llm "github.com/mutablelogic/go-llm" + ollama "github.com/mutablelogic/go-llm/pkg/ollama" + tool "github.com/mutablelogic/go-llm/pkg/tool" + assert "github.com/stretchr/testify/assert" +) + +func Test_chat_001(t *testing.T) { + client, err := ollama.New(GetEndpoint(t), opts.OptTrace(os.Stderr, true)) + if err != nil { + t.FailNow() + } + + // Pull the model + model, err := client.PullModel(context.TODO(), "qwen:0.5b", ollama.WithPullStatus(func(status *ollama.PullStatus) { + t.Log(status) + })) + if err != nil { + t.FailNow() + } + + t.Run("ChatStream", func(t *testing.T) { + assert := assert.New(t) + response, err := client.Chat(context.TODO(), model.UserPrompt("why is the sky blue?"), llm.WithStream(func(stream llm.Context) { + t.Log(stream) + })) + if !assert.NoError(err) { + t.FailNow() + } + t.Log(response) + }) + + t.Run("ChatNoStream", func(t *testing.T) { + assert := assert.New(t) + response, err := client.Chat(context.TODO(), model.UserPrompt("why is the sky green?")) + if !assert.NoError(err) { + t.FailNow() + } + t.Log(response) + }) +} + +func Test_chat_002(t *testing.T) { + client, err := ollama.New(GetEndpoint(t), opts.OptTrace(os.Stderr, true)) + if err != nil { + t.FailNow() + } + + // Pull the model + model, err := client.PullModel(context.TODO(), "llama3.2:1b", ollama.WithPullStatus(func(status *ollama.PullStatus) { + t.Log(status) + })) + if err != nil { + t.FailNow() + } + + // Make a toolkit + toolkit := tool.NewToolKit() + if err := toolkit.Register(new(weather)); err != nil { + t.FailNow() + } + + t.Run("Tools", func(t *testing.T) { + assert := assert.New(t) + response, err := client.Chat(context.TODO(), + model.UserPrompt("what is the weather in berlin?"), + llm.WithToolKit(toolkit), + ) + if !assert.NoError(err) { + t.FailNow() + } + t.Log(response) + }) +} + +func Test_chat_003(t *testing.T) { + client, err := ollama.New(GetEndpoint(t), opts.OptTrace(os.Stderr, false)) + if err != nil { + t.FailNow() + } + + // Pull the model + model, err := client.PullModel(context.TODO(), "llava", ollama.WithPullStatus(func(status *ollama.PullStatus) { + t.Log(status) + })) + if err != nil { + t.FailNow() + } + + // Explain the content of an image + t.Run("Image", func(t *testing.T) { + assert := assert.New(t) + + f, err := os.Open("testdata/guggenheim.jpg") + if !assert.NoError(err) { + t.FailNow() + } + defer f.Close() + + response, err := client.Chat(context.TODO(), + model.UserPrompt("describe this photo to me", llm.WithAttachment(f)), + ) + if !assert.NoError(err) { + t.FailNow() + } + t.Log(response) + }) +} + +//////////////////////////////////////////////////////////////////////////////// +// TOOLS + +type weather struct { + Location string `json:"location" name:"location" help:"The location to get the weather for" required:"true"` +} + +func (*weather) Name() string { + return "weather_in_location" +} + +func (*weather) Description() string { + return "Get the weather in a location" +} + +func (weather *weather) String() string { + data, err := json.MarshalIndent(weather, "", " ") + if err != nil { + return err.Error() + } + return string(data) +} + +func (weather *weather) Run(ctx context.Context) (any, error) { + log.Println("weather_in_location", "=>", weather) + return "very sunny today", nil +} diff --git a/pkg/ollama/client.go b/pkg/ollama/client.go new file mode 100644 index 0000000..56d9c62 --- /dev/null +++ b/pkg/ollama/client.go @@ -0,0 +1,49 @@ +package ollama + +import ( + + // Packages + client "github.com/mutablelogic/go-client" + llm "github.com/mutablelogic/go-llm" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +type Client struct { + *client.Client +} + +// Ensure it satisfies the agent.Agent interface +var _ llm.Agent = (*Client)(nil) + +/////////////////////////////////////////////////////////////////////////////// +// GLOBALS + +const ( + defaultName = "ollama" +) + +/////////////////////////////////////////////////////////////////////////////// +// LIFECYCLE + +// Create a new client, with an ollama endpoint, which should be something like +// "http://localhost:11434/api" +func New(endPoint string, opts ...client.ClientOpt) (*Client, error) { + // Create client + client, err := client.New(append(opts, client.OptEndpoint(endPoint))...) + if err != nil { + return nil, err + } + + // Return the client + return &Client{client}, nil +} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +// Return the name of the agent +func (*Client) Name() string { + return defaultName +} diff --git a/pkg/ollama/client_test.go b/pkg/ollama/client_test.go new file mode 100644 index 0000000..851b98b --- /dev/null +++ b/pkg/ollama/client_test.go @@ -0,0 +1,32 @@ +package ollama_test + +import ( + "os" + "testing" + + // Packages + opts "github.com/mutablelogic/go-client" + ollama "github.com/mutablelogic/go-llm/pkg/ollama" + assert "github.com/stretchr/testify/assert" +) + +func Test_client_001(t *testing.T) { + assert := assert.New(t) + client, err := ollama.New(GetEndpoint(t), opts.OptTrace(os.Stderr, true)) + if assert.NoError(err) { + assert.NotNil(client) + t.Log(client) + } +} + +/////////////////////////////////////////////////////////////////////////////// +// ENVIRONMENT + +func GetEndpoint(t *testing.T) string { + key := os.Getenv("OLLAMA_URL") + if key == "" { + t.Skip("OLLAMA_URL not set, skipping tests") + t.SkipNow() + } + return key +} diff --git a/pkg/ollama/doc.go b/pkg/ollama/doc.go new file mode 100644 index 0000000..c652fb9 --- /dev/null +++ b/pkg/ollama/doc.go @@ -0,0 +1,5 @@ +/* +ollama implements an API client for ollama +https://github.com/ollama/ollama/blob/main/docs/api.md +*/ +package ollama diff --git a/pkg/ollama/embedding.go b/pkg/ollama/embedding.go new file mode 100644 index 0000000..ceae604 --- /dev/null +++ b/pkg/ollama/embedding.go @@ -0,0 +1,95 @@ +package ollama + +import ( + "context" + "encoding/json" + "time" + + // Packages + client "github.com/mutablelogic/go-client" + llm "github.com/mutablelogic/go-llm" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +// model is the implementation of the llm.Embedding interface +type embedding struct { + EmbeddingMeta +} + +// EmbeddingMeta is the metadata for a generated embedding vector +type EmbeddingMeta struct { + Model string `json:"model"` + Embeddings [][]float64 `json:"embeddings"` + Metrics +} + +/////////////////////////////////////////////////////////////////////////////// +// STRINGIFY + +func (m embedding) String() string { + data, err := json.MarshalIndent(m.EmbeddingMeta, "", " ") + if err != nil { + return err.Error() + } + return string(data) +} + +func (m EmbeddingMeta) String() string { + data, err := json.MarshalIndent(m, "", " ") + if err != nil { + return err.Error() + } + return string(data) +} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +type reqEmbedding struct { + Model string `json:"model"` + Input []string `json:"input"` + KeepAlive *time.Duration `json:"keep_alive,omitempty"` + Truncate *bool `json:"truncate,omitempty"` + Options map[string]interface{} `json:"options,omitempty"` +} + +func (ollama *Client) GenerateEmbedding(ctx context.Context, name string, prompt []string, opts ...llm.Opt) (*EmbeddingMeta, error) { + // Apply options + opt, err := llm.ApplyOpts(opts...) + if err != nil { + return nil, err + } + + // Bail out is no prompt + if len(prompt) == 0 { + return nil, llm.ErrBadParameter.With("missing prompt") + } + + // Request + req, err := client.NewJSONRequest(reqEmbedding{ + Model: name, + Input: prompt, + Truncate: optTruncate(opt), + KeepAlive: optKeepAlive(opt), + Options: optOptions(opt), + }) + if err != nil { + return nil, err + } + + // Response + var response embedding + if err := ollama.DoWithContext(ctx, req, &response, client.OptPath("embed")); err != nil { + return nil, err + } + + // Return success + return &response.EmbeddingMeta, nil +} + +// Embedding vector generation +func (model *model) Embedding(context.Context, string, ...llm.Opt) ([]float64, error) { + return nil, llm.ErrNotImplemented +} diff --git a/pkg/ollama/embedding_test.go b/pkg/ollama/embedding_test.go new file mode 100644 index 0000000..77c854f --- /dev/null +++ b/pkg/ollama/embedding_test.go @@ -0,0 +1,28 @@ +package ollama_test + +import ( + "context" + "os" + "testing" + + // Packages + opts "github.com/mutablelogic/go-client" + ollama "github.com/mutablelogic/go-llm/pkg/ollama" + assert "github.com/stretchr/testify/assert" +) + +func Test_embed_001(t *testing.T) { + client, err := ollama.New(GetEndpoint(t), opts.OptTrace(os.Stderr, true)) + if err != nil { + t.FailNow() + } + + t.Run("Embedding", func(t *testing.T) { + assert := assert.New(t) + embedding, err := client.GenerateEmbedding(context.TODO(), "qwen:0.5b", []string{"world"}) + if !assert.NoError(err) { + t.FailNow() + } + t.Log(embedding) + }) +} diff --git a/pkg/ollama/message.go b/pkg/ollama/message.go new file mode 100644 index 0000000..53efe76 --- /dev/null +++ b/pkg/ollama/message.go @@ -0,0 +1,36 @@ +package ollama + +import ( + llm "github.com/mutablelogic/go-llm" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +// Chat Message +type MessageMeta struct { + Role string `json:"role"` + Content string `json:"content,omitempty"` + FunctionName string `json:"name,omitempty"` // Function name for a tool result + Images []Data `json:"images,omitempty"` // Image attachments + ToolCalls []ToolCall `json:"tool_calls,omitempty"` // Tool calls from the assistant +} + +type ToolCall struct { + Function ToolCallFunction `json:"function"` +} + +type ToolCallFunction struct { + Index int `json:"index,omitempty"` + Name string `json:"name"` + Arguments map[string]any `json:"arguments"` +} + +// Data represents the raw binary data of an image file. +type Data []byte + +// ToolFunction +type ToolFunction struct { + Type string `json:"type"` // function + Function llm.Tool `json:"function"` +} diff --git a/pkg/ollama/model.go b/pkg/ollama/model.go new file mode 100644 index 0000000..246d68d --- /dev/null +++ b/pkg/ollama/model.go @@ -0,0 +1,244 @@ +package ollama + +import ( + "context" + "encoding/json" + "net/http" + "time" + + // Packages + client "github.com/mutablelogic/go-client" + llm "github.com/mutablelogic/go-llm" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +// model is the implementation of the llm.Model interface +type model struct { + client *Client + ModelMeta +} + +var _ llm.Model = (*model)(nil) + +// ModelMeta is the metadata for an ollama model +type ModelMeta struct { + Name string `json:"name"` + Model string `json:"model,omitempty"` + ModifiedAt time.Time `json:"modified_at"` + Size int64 `json:"size,omitempty"` + Digest string `json:"digest,omitempty"` + Details ModelDetails `json:"details"` + File string `json:"modelfile,omitempty"` + Parameters string `json:"parameters,omitempty"` + Template string `json:"template,omitempty"` + Info ModelInfo `json:"model_info,omitempty"` +} + +// ModelDetails are the details of the model +type ModelDetails struct { + ParentModel string `json:"parent_model,omitempty"` + Format string `json:"format"` + Family string `json:"family"` + Families []string `json:"families"` + ParameterSize string `json:"parameter_size"` + QuantizationLevel string `json:"quantization_level"` +} + +// ModelInfo provides additional model parameters +type ModelInfo map[string]any + +// PullStatus provides the status of a pull operation in a callback function +type PullStatus struct { + Status string `json:"status"` + DigestName string `json:"digest,omitempty"` + TotalBytes int64 `json:"total,omitempty"` + CompletedBytes int64 `json:"completed,omitempty"` +} + +/////////////////////////////////////////////////////////////////////////////// +// STRINGIFY + +func (m model) String() string { + data, err := json.MarshalIndent(m.ModelMeta, "", " ") + if err != nil { + return err.Error() + } + return string(data) +} + +func (m PullStatus) String() string { + data, err := json.MarshalIndent(m, "", " ") + if err != nil { + return err.Error() + } + return string(data) +} + +/////////////////////////////////////////////////////////////////////////////// +// INTERFACE IMPLEMENTATION + +func (m model) Name() string { + return m.ModelMeta.Name +} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +// Agent interface +func (ollama *Client) Models(ctx context.Context) ([]llm.Model, error) { + return ollama.ListModels(ctx) +} + +// List models +func (ollama *Client) ListModels(ctx context.Context) ([]llm.Model, error) { + type respListModel struct { + Models []*model `json:"models"` + } + + // Send the request + var response respListModel + if err := ollama.DoWithContext(ctx, nil, &response, client.OptPath("tags")); err != nil { + return nil, err + } + + // Convert to llm.Model + result := make([]llm.Model, 0, len(response.Models)) + for _, model := range response.Models { + model.client = ollama + result = append(result, model) + } + + // Return models + return result, nil +} + +// List running models +func (ollama *Client) ListRunningModels(ctx context.Context) ([]llm.Model, error) { + type respListModel struct { + Models []*model `json:"models"` + } + + // Send the request + var response respListModel + if err := ollama.DoWithContext(ctx, nil, &response, client.OptPath("ps")); err != nil { + return nil, err + } + + // Convert to llm.Model + result := make([]llm.Model, 0, len(response.Models)) + for _, model := range response.Models { + model.client = ollama + result = append(result, model) + } + + // Return models + return result, nil +} + +// Get model details +func (ollama *Client) GetModel(ctx context.Context, name string) (llm.Model, error) { + type reqGetModel struct { + Model string `json:"model"` + } + + // Request + req, err := client.NewJSONRequest(reqGetModel{ + Model: name, + }) + if err != nil { + return nil, err + } + + // Response + var response model + if err := ollama.DoWithContext(ctx, req, &response, client.OptPath("show")); err != nil { + return nil, err + } else { + response.client = ollama + response.ModelMeta.Name = name + } + + // Return success + return &response, nil +} + +// Copy a local model by name +func (ollama *Client) CopyModel(ctx context.Context, source, destination string) error { + type reqCopyModel struct { + Source string `json:"source"` + Destination string `json:"destination"` + } + + // Request + req, err := client.NewJSONRequest(reqCopyModel{ + Source: source, + Destination: destination, + }) + if err != nil { + return err + } + + // Response + return ollama.Do(req, nil, client.OptPath("copy")) +} + +// Delete a local model by name +func (ollama *Client) DeleteModel(ctx context.Context, name string) error { + type reqGetModel struct { + Model string `json:"model"` + } + + // Request + req, err := client.NewJSONRequestEx(http.MethodDelete, reqGetModel{ + Model: name, + }, client.ContentTypeAny) + if err != nil { + return err + } + + // Response + return ollama.Do(req, nil, client.OptPath("delete")) +} + +// Pull a remote model locally +func (ollama *Client) PullModel(ctx context.Context, name string, opts ...llm.Opt) (llm.Model, error) { + type reqPullModel struct { + Model string `json:"model"` + Insecure bool `json:"insecure,omitempty"` + Stream bool `json:"stream"` + } + + // Apply options + opt, err := llm.ApplyOpts(opts...) + if err != nil { + return nil, err + } + + // Request + req, err := client.NewJSONRequest(reqPullModel{ + Model: name, + Stream: optPullStatus(opt) != nil, + Insecure: optInsecure(opt), + }) + if err != nil { + return nil, err + } + + // Response + var response PullStatus + if err := ollama.DoWithContext(ctx, req, &response, client.OptPath("pull"), client.OptNoTimeout(), client.OptJsonStreamCallback(func(v any) error { + if v, ok := v.(*PullStatus); ok && v != nil { + if fn := optPullStatus(opt); fn != nil { + fn(v) + } + } + return nil + })); err != nil { + return nil, err + } + + // Return success + return ollama.GetModel(ctx, name) +} diff --git a/pkg/ollama/model_test.go b/pkg/ollama/model_test.go new file mode 100644 index 0000000..db14c9d --- /dev/null +++ b/pkg/ollama/model_test.go @@ -0,0 +1,74 @@ +package ollama_test + +import ( + "context" + "os" + "testing" + + // Packages + opts "github.com/mutablelogic/go-client" + ollama "github.com/mutablelogic/go-llm/pkg/ollama" + assert "github.com/stretchr/testify/assert" +) + +func Test_model_001(t *testing.T) { + client, err := ollama.New(GetEndpoint(t), opts.OptTrace(os.Stderr, true)) + if err != nil { + t.FailNow() + } + + var names []string + t.Run("Models", func(t *testing.T) { + assert := assert.New(t) + models, err := client.Models(context.TODO()) + if !assert.NoError(err) { + t.FailNow() + } + assert.NotNil(models) + for _, model := range models { + names = append(names, model.Name()) + } + }) + + t.Run("Model", func(t *testing.T) { + assert := assert.New(t) + for _, name := range names { + model, err := client.GetModel(context.TODO(), name) + if !assert.NoError(err) { + t.FailNow() + } + t.Log(model) + } + }) + + t.Run("PullModel", func(t *testing.T) { + assert := assert.New(t) + model, err := client.PullModel(context.TODO(), "qwen:0.5b", ollama.WithPullStatus(func(status *ollama.PullStatus) { + t.Log(status) + })) + if !assert.NoError(err) { + t.FailNow() + } + assert.NotNil(model) + }) + + t.Run("CopyModel", func(t2 *testing.T) { + assert := assert.New(t) + err := client.CopyModel(context.TODO(), "qwen:0.5b", t.Name()) + if !assert.NoError(err) { + t.FailNow() + } + }) + + t.Run("DeleteModel", func(t2 *testing.T) { + assert := assert.New(t) + _, err = client.GetModel(context.TODO(), t.Name()) + if !assert.NoError(err) { + t.FailNow() + } + err := client.DeleteModel(context.TODO(), t.Name()) + if !assert.NoError(err) { + t.FailNow() + } + }) +} diff --git a/pkg/ollama/opt.go b/pkg/ollama/opt.go new file mode 100644 index 0000000..f5a28d0 --- /dev/null +++ b/pkg/ollama/opt.go @@ -0,0 +1,145 @@ +package ollama + +import ( + "time" + + // Packages + llm "github.com/mutablelogic/go-llm" +) + +//////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +// Pull Model: Allow insecure connections for pulling models. +func WithInsecure() llm.Opt { + return func(o *llm.Opts) error { + o.Set("insecure", true) + return nil + } +} + +// Embeddings: Does not truncate the end of each input to fit within context length. Returns error if context length is exceeded. +func WithTruncate(v bool) llm.Opt { + return func(o *llm.Opts) error { + o.Set("truncate", v) + return nil + } +} + +// Embeddings & Chat: Controls how long the model will stay loaded into memory following the request. +func WithKeepAlive(v time.Duration) llm.Opt { + return func(o *llm.Opts) error { + if v <= 0 { + return llm.ErrBadParameter.With("keepalive must be greater than zero") + } + o.Set("keepalive", v) + return nil + } +} + +// Pull Model: Stream the response as it is received. +func WithPullStatus(fn func(*PullStatus)) llm.Opt { + return func(o *llm.Opts) error { + o.Set("pullstatus", fn) + return nil + } +} + +// Embeddings & Chat: model-specific options. +func WithOption(key string, value any) llm.Opt { + return func(o *llm.Opts) error { + if opts, ok := o.Get("options").(map[string]any); !ok { + o.Set("options", map[string]any{key: value}) + } else { + opts[key] = value + } + return nil + } +} + +//////////////////////////////////////////////////////////////////////////////// +// PRIVATE METHODS + +func optInsecure(opts *llm.Opts) bool { + return opts.GetBool("insecure") +} + +func optTruncate(opts *llm.Opts) *bool { + if !opts.Has("truncate") { + return nil + } + v := opts.GetBool("truncate") + return &v +} + +func optPullStatus(opts *llm.Opts) func(*PullStatus) { + if fn, ok := opts.Get("pullstatus").(func(*PullStatus)); ok && fn != nil { + return fn + } + return nil +} + +func optSystemPrompt(opts *llm.Opts) string { + return opts.SystemPrompt() +} + +func optTools(agent *Client, opts *llm.Opts) []ToolFunction { + toolkit := opts.ToolKit() + if toolkit == nil { + return nil + } + tools := toolkit.Tools(agent) + result := make([]ToolFunction, 0, len(tools)) + for _, tool := range tools { + result = append(result, ToolFunction{ + Type: "function", + Function: tool, + }) + } + return result +} + +func optFormat(opts *llm.Opts) string { + return opts.GetString("format") +} + +func optOptions(opts *llm.Opts) map[string]any { + result := make(map[string]any) + if o, ok := opts.Get("options").(map[string]any); ok { + for k, v := range o { + result[k] = v + } + } + + // copy across temperature, top_p and top_k + if opts.Has("temperature") { + result["temperature"] = opts.Get("temperature") + } + if opts.Has("top_p") { + result["top_p"] = opts.Get("top_p") + } + if opts.Has("top_k") { + result["top_k"] = opts.Get("top_k") + } + + // Return result + return result +} + +func optStream(agent *Client, opts *llm.Opts) bool { + // Streaming only if there is a stream function and no tools + toolkit := opts.ToolKit() + if toolkit != nil { + if tools := toolkit.Tools(agent); len(tools) > 0 { + return false + } + } + return opts.StreamFn() != nil +} + +func optKeepAlive(opts *llm.Opts) *time.Duration { + if v := opts.GetDuration("keepalive"); v > 0 { + return &v + } + return nil +} diff --git a/pkg/ollama/session.go b/pkg/ollama/session.go new file mode 100644 index 0000000..867f60a --- /dev/null +++ b/pkg/ollama/session.go @@ -0,0 +1,196 @@ +package ollama + +import ( + "context" + "encoding/json" + "fmt" + + // Packages + llm "github.com/mutablelogic/go-llm" + "github.com/mutablelogic/go-llm/pkg/tool" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +// Implementation of a message session, which is a sequence of messages +type session struct { + opts []llm.Opt + model *model + seq []*MessageMeta +} + +var _ llm.Context = (*session)(nil) + +/////////////////////////////////////////////////////////////////////////////// +// LIFECYCLE + +// Create a new empty context +func (model *model) Context(opts ...llm.Opt) llm.Context { + return &session{ + model: model, + opts: opts, + } +} + +// Create a new context with a user prompt +func (model *model) UserPrompt(prompt string, opts ...llm.Opt) llm.Context { + context := model.Context(opts...) + context.(*session).seq = append(context.(*session).seq, &MessageMeta{ + Role: "user", + Content: prompt, + }) + return context +} + +/////////////////////////////////////////////////////////////////////////////// +// STRINGIFY + +func (session session) String() string { + var data []byte + var err error + if len(session.seq) == 1 { + data, err = json.MarshalIndent(session.seq[0], "", " ") + } else { + data, err = json.MarshalIndent(session.seq, "", " ") + } + if err != nil { + return err.Error() + } + return string(data) +} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +// Generate a response from a user prompt (with attachments) +func (s *session) FromUser(ctx context.Context, prompt string, opts ...llm.Opt) error { + // Append the user prompt + if user, err := userPrompt(prompt, opts...); err != nil { + return err + } else { + s.seq = append(s.seq, user) + } + + // The options come from the session options and the user options + chatopts := make([]llm.Opt, 0, len(s.opts)+len(opts)) + chatopts = append(chatopts, s.opts...) + chatopts = append(chatopts, opts...) + + // Call the 'chat' method + client := s.model.client + r, err := client.Chat(ctx, s, chatopts...) + if err != nil { + return err + } else { + s.seq = append(s.seq, &r.Message) + } + + // Return success + return nil +} + +// Generate a response from a tool calling result +func (s *session) FromTool(ctx context.Context, results ...llm.ToolResult) error { + if len(results) == 0 { + return llm.ErrConflict.Withf("No tool results") + } + + // Append the tool results + for _, result := range results { + if message, err := toolResult(result); err != nil { + return err + } else { + s.seq = append(s.seq, message) + } + } + + // Call the 'chat' method + r, err := s.model.client.Chat(ctx, s, s.opts...) + if err != nil { + return err + } else { + s.seq = append(s.seq, &r.Message) + } + + // Return success + return nil +} + +// Return the role of the last message +func (session *session) Role() string { + if len(session.seq) == 0 { + return "" + } + return session.seq[len(session.seq)-1].Role +} + +// Return the text of the last message +func (session *session) Text() string { + if len(session.seq) == 0 { + return "" + } + return session.seq[len(session.seq)-1].Content +} + +// Return the tool calls of the last message +func (session *session) ToolCalls() []llm.ToolCall { + // Sanity check for tool call + if len(session.seq) == 0 { + return nil + } + meta := session.seq[len(session.seq)-1] + if meta.Role != "assistant" { + return nil + } + + // Gather tool calls + var result []llm.ToolCall + for _, call := range meta.ToolCalls { + result = append(result, tool.NewCall(fmt.Sprint(call.Function.Index), call.Function.Name, call.Function.Arguments)) + } + return result +} + +/////////////////////////////////////////////////////////////////////////////// +// PRIVATE METHODS + +func userPrompt(prompt string, opts ...llm.Opt) (*MessageMeta, error) { + // Apply options for attachments + opt, err := llm.ApplyOpts(opts...) + if err != nil { + return nil, err + } + + // Create a new message + var meta MessageMeta + meta.Role = "user" + meta.Content = prompt + + if attachments := opt.Attachments(); len(attachments) > 0 { + meta.Images = make([]Data, len(attachments)) + for i, attachment := range attachments { + meta.Images[i] = attachment.Data() + } + } + + // Return success + return &meta, nil +} + +func toolResult(result llm.ToolResult) (*MessageMeta, error) { + // Turn result into JSON + data, err := json.Marshal(result.Value()) + if err != nil { + return nil, err + } + + // Create a new message + var meta MessageMeta + meta.Role = "tool" + meta.FunctionName = result.Call().Name() + meta.Content = string(data) + + // Return success + return &meta, nil +} diff --git a/pkg/ollama/session_test.go b/pkg/ollama/session_test.go new file mode 100644 index 0000000..e4df9d4 --- /dev/null +++ b/pkg/ollama/session_test.go @@ -0,0 +1,90 @@ +package ollama_test + +import ( + "context" + "os" + "testing" + + // Packages + opts "github.com/mutablelogic/go-client" + llm "github.com/mutablelogic/go-llm" + ollama "github.com/mutablelogic/go-llm/pkg/ollama" + "github.com/mutablelogic/go-llm/pkg/tool" + assert "github.com/stretchr/testify/assert" +) + +func Test_session_001(t *testing.T) { + client, err := ollama.New(GetEndpoint(t), opts.OptTrace(os.Stderr, true)) + if err != nil { + t.FailNow() + } + + // Pull the model + model, err := client.PullModel(context.TODO(), "qwen:0.5b") + if err != nil { + t.FailNow() + } + + // Session with a single user prompt - streaming + t.Run("stream", func(t *testing.T) { + assert := assert.New(t) + session := model.Context(llm.WithStream(func(stream llm.Context) { + t.Log("SESSION DELTA", stream) + })) + assert.NotNil(session) + + err := session.FromUser(context.TODO(), "Why is the grass green?") + if !assert.NoError(err) { + t.FailNow() + } + assert.Equal("assistant", session.Role()) + assert.NotEmpty(session.Text()) + }) + + // Session with a single user prompt - not streaming + t.Run("nostream", func(t *testing.T) { + assert := assert.New(t) + session := model.Context() + assert.NotNil(session) + + err := session.FromUser(context.TODO(), "Why is the sky blue?") + if !assert.NoError(err) { + t.FailNow() + } + assert.Equal("assistant", session.Role()) + assert.NotEmpty(session.Text()) + }) +} + +func Test_session_002(t *testing.T) { + client, err := ollama.New(GetEndpoint(t), opts.OptTrace(os.Stderr, true)) + if err != nil { + t.FailNow() + } + + // Pull the model + model, err := client.PullModel(context.TODO(), "llama3.2") + if err != nil { + t.FailNow() + } + + // Make a toolkit + toolkit := tool.NewToolKit() + if err := toolkit.Register(new(weather)); err != nil { + t.FailNow() + } + + // Session with a tool call + t.Run("toolcall", func(t *testing.T) { + assert := assert.New(t) + + session := model.Context(llm.WithToolKit(toolkit)) + assert.NotNil(session) + + err = session.FromUser(context.TODO(), "What is today's weather in Berlin?", llm.WithTemperature(0.5)) + if !assert.NoError(err) { + t.FailNow() + } + t.Log(session) + }) +} diff --git a/pkg/ollama/testdata/guggenheim.jpg b/pkg/ollama/testdata/guggenheim.jpg new file mode 100644 index 0000000..7e16517 Binary files /dev/null and b/pkg/ollama/testdata/guggenheim.jpg differ diff --git a/pkg/tool/call.go b/pkg/tool/call.go new file mode 100644 index 0000000..47cfdf0 --- /dev/null +++ b/pkg/tool/call.go @@ -0,0 +1,67 @@ +package tool + +import ( + // Packages + + "encoding/json" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +type CallMeta struct { + Name string `json:"name"` + Id string `json:"id,omitempty"` + Input map[string]any `json:"input,omitempty"` +} + +type call struct { + meta CallMeta +} + +/////////////////////////////////////////////////////////////////////////////// +// LIFECYCLE + +func NewCall(id, name string, input map[string]any) *call { + return &call{ + meta: CallMeta{ + Name: name, + Id: id, + Input: input, + }, + } +} + +/////////////////////////////////////////////////////////////////////////////// +// STRINGIFY + +func (t call) MarshalJSON() ([]byte, error) { + return json.Marshal(t.meta) +} + +func (t call) String() string { + data, err := json.MarshalIndent(t.meta, "", " ") + if err != nil { + return err.Error() + } + return string(data) +} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +func (t call) Name() string { + return t.meta.Name +} + +func (t call) Id() string { + return t.meta.Id +} + +func (t call) Decode(v any) error { + if data, err := json.Marshal(t.meta.Input); err != nil { + return err + } else { + return json.Unmarshal(data, v) + } +} diff --git a/pkg/tool/old/tool.go_old b/pkg/tool/old/tool.go_old new file mode 100644 index 0000000..2da2bee --- /dev/null +++ b/pkg/tool/old/tool.go_old @@ -0,0 +1,220 @@ +package ollama + +import ( + "encoding/json" + "errors" + "fmt" + "reflect" + "strings" + + // Packages + llm "github.com/mutablelogic/go-llm" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +type Tool struct { + Type string `json:"type"` + Function ToolFunction `json:"function"` +} + +type ToolFunction struct { + Name string `json:"name"` + Description string `json:"description"` + Parameters struct { + Type string `json:"type,omitempty"` + Required []string `json:"required,omitempty"` + Properties map[string]ToolParameter `json:"properties,omitempty"` + } `json:"parameters"` + proto reflect.Type // Prototype for parameter return +} + +type ToolParameter struct { + Name string `json:"-"` + Type string `json:"type"` + Description string `json:"description,omitempty"` + Enum []string `json:"enum,omitempty"` + required bool + index []int // Field index into prototype for setting a field +} + +/////////////////////////////////////////////////////////////////////////////// +// LIFECYCLE + +// Return a tool, or panic if there is an error +func MustTool(name, description string, params any) *Tool { + tool, err := NewTool(name, description, params) + if err != nil { + panic(err) + } + return tool +} + +// Return a new tool definition +func NewTool(name, description string, params any) (*Tool, error) { + tool := Tool{ + Type: "function", + Function: ToolFunction{Name: name, Description: description, proto: reflect.TypeOf(params)}, + } + + // Add parameters + tool.Function.Parameters.Type = "object" + if params, err := paramsFor(params); err != nil { + return nil, err + } else { + tool.Function.Parameters.Required = make([]string, 0, len(params)) + tool.Function.Parameters.Properties = make(map[string]ToolParameter, len(params)) + for _, param := range params { + if _, exists := tool.Function.Parameters.Properties[param.Name]; exists { + return nil, llm.ErrConflict.Withf("parameter %q already exists", param.Name) + } else { + tool.Function.Parameters.Properties[param.Name] = param + } + if param.required { + tool.Function.Parameters.Required = append(tool.Function.Parameters.Required, param.Name) + } + } + } + + // Return success + return &tool, nil +} + +// Return a new tool call +func NewToolCall(v ToolCall) *ToolCallFunction { + return &v.Function +} + +/////////////////////////////////////////////////////////////////////////////// +// STRINGIFY + +func (t Tool) String() string { + data, err := json.MarshalIndent(t, "", " ") + if err != nil { + return err.Error() + } + return string(data) +} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +func (t *Tool) Params(call ToolCall) (any, error) { + if call.Function.Name != t.Function.Name { + return nil, llm.ErrBadParameter.Withf("invalid function %q, expected %q", call.Function.Name, t.Function.Name) + } + + // Create parameters + params := reflect.New(t.Function.proto).Elem() + + // Iterate over arguments + var result error + for name, value := range call.Function.Arguments { + param, exists := t.Function.Parameters.Properties[name] + if !exists { + return nil, llm.ErrBadParameter.Withf("invalid argument %q", name) + } + result = errors.Join(result, paramSet(params.FieldByIndex(param.index), value)) + } + + // Return any errors + return params.Interface(), result +} + +/////////////////////////////////////////////////////////////////////////////// +// PRIVATE METHODS + +// Return tool parameters from a struct +func paramsFor(params any) ([]ToolParameter, error) { + if params == nil { + return []ToolParameter{}, nil + } + rt := reflect.TypeOf(params) + if rt.Kind() == reflect.Ptr { + rt = rt.Elem() + } + if rt.Kind() != reflect.Struct { + return nil, llm.ErrBadParameter.With("params must be a struct") + } + + // Iterate over fields + fields := reflect.VisibleFields(rt) + result := make([]ToolParameter, 0, len(fields)) + for _, field := range fields { + if param, err := paramFor(field); err != nil { + return nil, err + } else { + result = append(result, param) + } + } + + // Return success + return result, nil +} + +// Return tool parameters from a struct field +func paramFor(field reflect.StructField) (ToolParameter, error) { + // Name + name := field.Tag.Get("name") + if name == "" { + name = field.Name + } + + // Type + typ, err := paramType(field) + if err != nil { + return ToolParameter{}, err + } + + // Required + _, required := field.Tag.Lookup("required") + + // Enum + enum := []string{} + if enum_ := field.Tag.Get("enum"); enum_ != "" { + enum = strings.Split(enum_, ",") + } + + // Return success + return ToolParameter{ + Name: field.Name, + Type: typ, + Description: field.Tag.Get("help"), + Enum: enum, + required: required, + index: field.Index, + }, nil +} + +var ( + typeString = reflect.TypeOf("") + typeUint = reflect.TypeOf(uint(0)) + typeInt = reflect.TypeOf(int(0)) + typeFloat64 = reflect.TypeOf(float64(0)) + typeFloat32 = reflect.TypeOf(float32(0)) +) + +// Return parameter type from a struct field +func paramType(field reflect.StructField) (string, error) { + t := field.Type + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + switch field.Type { + case typeString: + return "string", nil + case typeUint, typeInt: + return "integer", nil + case typeFloat64, typeFloat32: + return "number", nil + default: + return "", llm.ErrBadParameter.Withf("unsupported type %v for field %q", field.Type, field.Name) + } +} + +// Set a field parameter +func paramSet(field reflect.Value, v any) error { + fmt.Println("TODO", field, "=>", v) + return nil +} diff --git a/pkg/tool/old/tool.go_old_old b/pkg/tool/old/tool.go_old_old new file mode 100644 index 0000000..2d0a24b --- /dev/null +++ b/pkg/tool/old/tool.go_old_old @@ -0,0 +1,216 @@ +package anthropic + +import ( + "encoding/json" + "reflect" + "strings" + + // Packages + llm "github.com/mutablelogic/go-llm" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +type Tool struct { + Name string `json:"name"` + Description string `json:"description"` + Parameters struct { + Type string `json:"type,omitempty"` + Required []string `json:"required,omitempty"` + Properties map[string]ToolParameter `json:"properties,omitempty"` + } `json:"input_schema"` + proto reflect.Type // Prototype for parameter return +} + +type ToolParameter struct { + Name string `json:"-"` + Type string `json:"type"` + Description string `json:"description,omitempty"` + Enum []string `json:"enum,omitempty"` + required bool + index []int // Field index into prototype for setting a field +} + +type toolcall struct { + ContentTool +} + +/////////////////////////////////////////////////////////////////////////////// +// LIFECYCLE + +// Return a tool, or panic if there is an error +func MustTool(name, description string, params any) *Tool { + tool, err := NewTool(name, description, params) + if err != nil { + panic(err) + } + return tool +} + +// Return a new tool definition +func NewTool(name, description string, params any) (*Tool, error) { + tool := Tool{ + Name: name, + Description: description, + proto: reflect.TypeOf(params), + } + + // Add parameters + tool.Parameters.Type = "object" + toolparams, err := paramsFor(params) + if err != nil { + return nil, err + } + + // Set parameters + tool.Parameters.Required = make([]string, 0, len(toolparams)) + tool.Parameters.Properties = make(map[string]ToolParameter, len(toolparams)) + for _, param := range toolparams { + if _, exists := tool.Parameters.Properties[param.Name]; exists { + return nil, llm.ErrConflict.Withf("parameter %q already exists", param.Name) + } else { + tool.Parameters.Properties[param.Name] = param + } + if param.required { + tool.Parameters.Required = append(tool.Parameters.Required, param.Name) + } + } + + // Return success + return &tool, nil +} + +// Return a new tool call from a content parameter +func NewToolCall(content *Content) *toolcall { + if content == nil || content.ContentTool.Id == "" || content.ContentTool.Name == "" { + return nil + } + return &toolcall{content.ContentTool} +} + +/////////////////////////////////////////////////////////////////////////////// +// STRINGIFY + +func (t Tool) String() string { + data, err := json.MarshalIndent(t, "", " ") + if err != nil { + return err.Error() + } + return string(data) +} + +func (t toolcall) String() string { + data, err := json.MarshalIndent(t, "", " ") + if err != nil { + return err.Error() + } + return string(data) +} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +func (t *toolcall) Name() string { + return t.ContentTool.Name +} + +func (t *toolcall) Id() string { + return t.ContentTool.Id +} + +func (t *toolcall) Params() any { + // TODO: Convert + return t.ContentTool.Input +} + +/////////////////////////////////////////////////////////////////////////////// +// PRIVATE METHODS + +// Return tool parameters from a struct +func paramsFor(params any) ([]ToolParameter, error) { + if params == nil { + return []ToolParameter{}, nil + } + rt := reflect.TypeOf(params) + if rt.Kind() == reflect.Ptr { + rt = rt.Elem() + } + if rt.Kind() != reflect.Struct { + return nil, llm.ErrBadParameter.With("params must be a struct") + } + + // Iterate over fields + fields := reflect.VisibleFields(rt) + result := make([]ToolParameter, 0, len(fields)) + for _, field := range fields { + if param, err := paramFor(field); err != nil { + return nil, err + } else { + result = append(result, param) + } + } + + // Return success + return result, nil +} + +// Return tool parameters from a struct field +func paramFor(field reflect.StructField) (ToolParameter, error) { + // Name + name := field.Tag.Get("name") + if name == "" { + name = field.Name + } + + // Type + typ, err := paramType(field) + if err != nil { + return ToolParameter{}, err + } + + // Required + _, required := field.Tag.Lookup("required") + + // Enum + enum := []string{} + if enum_ := field.Tag.Get("enum"); enum_ != "" { + enum = strings.Split(enum_, ",") + } + + // Return success + return ToolParameter{ + Name: field.Name, + Type: typ, + Description: field.Tag.Get("help"), + Enum: enum, + required: required, + index: field.Index, + }, nil +} + +var ( + typeString = reflect.TypeOf("") + typeUint = reflect.TypeOf(uint(0)) + typeInt = reflect.TypeOf(int(0)) + typeFloat64 = reflect.TypeOf(float64(0)) + typeFloat32 = reflect.TypeOf(float32(0)) +) + +// Return parameter type from a struct field +func paramType(field reflect.StructField) (string, error) { + t := field.Type + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + switch field.Type { + case typeString: + return "string", nil + case typeUint, typeInt: + return "integer", nil + case typeFloat64, typeFloat32: + return "number", nil + default: + return "", llm.ErrBadParameter.Withf("unsupported type %v for field %q", field.Type, field.Name) + } +} diff --git a/pkg/tool/old/tool_test.go_old b/pkg/tool/old/tool_test.go_old new file mode 100644 index 0000000..f4d1d30 --- /dev/null +++ b/pkg/tool/old/tool_test.go_old @@ -0,0 +1,29 @@ +package ollama_test + +import ( + "testing" + + // Packagees + + ollama "github.com/mutablelogic/go-llm/pkg/ollama" +) + +func Test_tool_001(t *testing.T) { + tool, err := ollama.NewTool("test", "test_tool", struct{}{}) + if err != nil { + t.FailNow() + } + t.Log(tool) +} + +func Test_tool_002(t *testing.T) { + tool, err := ollama.NewTool("test", "test_tool", struct { + A string `help:"A string"` + B int `help:"An integer"` + C float64 `help:"A float" required:""` + }{}) + if err != nil { + t.FailNow() + } + t.Log(tool) +} diff --git a/pkg/tool/result.go b/pkg/tool/result.go new file mode 100644 index 0000000..d3ba19e --- /dev/null +++ b/pkg/tool/result.go @@ -0,0 +1,62 @@ +package tool + +import ( + // Packages + "encoding/json" + + llm "github.com/mutablelogic/go-llm" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +type ResultMeta struct { + Call llm.ToolCall `json:"call"` + Value any `json:"result"` +} + +type result struct { + meta ResultMeta +} + +var _ llm.ToolResult = (*result)(nil) + +/////////////////////////////////////////////////////////////////////////////// +// LIFECYCLE + +func NewResult(call llm.ToolCall, value any) llm.ToolResult { + return &result{ + meta: ResultMeta{ + Call: call, + Value: value, + }, + } +} + +/////////////////////////////////////////////////////////////////////////////// +// STRINGIFY + +func (r result) MarshalJSON() ([]byte, error) { + return json.Marshal(r.meta) +} + +func (r result) String() string { + data, err := json.MarshalIndent(r.meta, "", " ") + if err != nil { + return err.Error() + } + return string(data) +} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +// The call associated with the result +func (r result) Call() llm.ToolCall { + return r.meta.Call +} + +// The result, which can be encoded into json +func (r result) Value() any { + return r.meta.Value +} diff --git a/pkg/tool/tool.go b/pkg/tool/tool.go new file mode 100644 index 0000000..5b22920 --- /dev/null +++ b/pkg/tool/tool.go @@ -0,0 +1,198 @@ +package tool + +import ( + "encoding/json" + "reflect" + "strings" + + // Packages + llm "github.com/mutablelogic/go-llm" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +type tool struct { + llm.Tool `json:"-"` + ToolMeta +} + +var _ llm.Tool = (*tool)(nil) + +type ToolMeta struct { + Name string `json:"name"` + Description string `json:"description"` + + // Variation on how schema is output + Parameters *ToolParameters `json:"parameters,omitempty"` + InputSchema *ToolParameters `json:"input_schema,omitempty"` +} + +type ToolParameters struct { + Type string `json:"type,omitempty"` + Required []string `json:"required"` + Properties map[string]ToolParameter `json:"properties"` +} + +type ToolParameter struct { + Name string `json:"-"` + Type string `json:"type"` + Description string `json:"description,omitempty"` + Enum []string `json:"enum,omitempty"` + required bool + index []int // Field index into prototype for setting a field +} + +//////////////////////////////////////////////////////////////////////////////// +// STRINGIFY + +func (t tool) String() string { + data, err := json.MarshalIndent(t.ToolMeta, "", " ") + if err != nil { + return err.Error() + } + return string(data) +} + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +func (t tool) Name() string { + return t.ToolMeta.Name +} + +func (t tool) Description() string { + return t.ToolMeta.Description +} + +/////////////////////////////////////////////////////////////////////////////// +// PRIVATE METHODS + +// Return tool parameters from a struct +func paramsFor(root []int, params any) ([]ToolParameter, error) { + if params == nil { + return []ToolParameter{}, nil + } + rt := reflect.TypeOf(params) + if rt.Kind() == reflect.Ptr { + rt = rt.Elem() + } + if rt.Kind() != reflect.Struct { + return nil, llm.ErrBadParameter.With("params must be a struct") + } + + return paramsForStruct(root, rt) +} + +func paramsForStruct(root []int, rt reflect.Type) ([]ToolParameter, error) { + result := make([]ToolParameter, 0, rt.NumField()) + + // Iterate over fields + for i := 0; i < rt.NumField(); i++ { + field := rt.Field(i) + + // Ignore unexported fields + name := fieldName(field) + if name == "" { + continue + } + + // Recurse into struct + ft := field.Type + if ft.Kind() == reflect.Ptr { + ft = field.Type.Elem() + } + if ft.Kind() == reflect.Struct { + if param, err := paramsForStruct(append(root, field.Index...), ft); err != nil { + return nil, err + } else { + result = append(result, param...) + } + continue + } + + // Determine parameter + if param, err := paramFor(root, field); err != nil { + return nil, err + } else { + result = append(result, param) + } + } + + // Return success + return result, nil +} + +// Return tool parameters from a struct field +func paramFor(root []int, field reflect.StructField) (ToolParameter, error) { + // Type + typ, err := paramType(field) + if err != nil { + return ToolParameter{}, err + } + + // Required + _, required := field.Tag.Lookup("required") + + // Enum + enum := []string{} + if enum_ := field.Tag.Get("enum"); enum_ != "" { + for _, e := range strings.Split(enum_, ",") { + enum = append(enum, strings.TrimSpace(e)) + } + } + + // Return success + return ToolParameter{ + Name: fieldName(field), + Type: typ, + Description: field.Tag.Get("help"), + Enum: enum, + required: required, + index: append(root, field.Index...), + }, nil +} + +// Return the name field, or empty name if field +// should be ignored +func fieldName(field reflect.StructField) string { + name, exists := field.Tag.Lookup("name") + if !exists { + name, exists = field.Tag.Lookup("json") + if names := strings.Split(name, ","); exists && len(names) > 0 { + name = names[0] + } + } + if !exists { + name = field.Name + } else if name == "-" { + return "" + } + return name +} + +var ( + typeString = reflect.TypeOf("") + typeUint = reflect.TypeOf(uint(0)) + typeInt = reflect.TypeOf(int(0)) + typeFloat64 = reflect.TypeOf(float64(0)) + typeFloat32 = reflect.TypeOf(float32(0)) +) + +// Return parameter type from a struct field +func paramType(field reflect.StructField) (string, error) { + t := field.Type + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + switch field.Type { + case typeString: + return "string", nil + case typeUint, typeInt: + return "integer", nil + case typeFloat64, typeFloat32: + return "number", nil + default: + return "", llm.ErrBadParameter.Withf("unsupported type %v for field %q", field.Type, field.Name) + } +} diff --git a/pkg/tool/toolkit.go b/pkg/tool/toolkit.go new file mode 100644 index 0000000..6d9fd34 --- /dev/null +++ b/pkg/tool/toolkit.go @@ -0,0 +1,138 @@ +package tool + +import ( + "context" + "errors" + "sync" + + // Packages + llm "github.com/mutablelogic/go-llm" +) + +//////////////////////////////////////////////////////////////////////////////// +// TYPES + +// ToolKit represents a toolkit of tools +type ToolKit struct { + functions map[string]tool +} + +var _ llm.ToolKit = (*ToolKit)(nil) + +//////////////////////////////////////////////////////////////////////////////// +// LIFECYCLE + +// Create a new empty toolkit for an agent +func NewToolKit() *ToolKit { + return &ToolKit{ + functions: make(map[string]tool), + } +} + +//////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +// Return all registered tools for a specific agent +func (kit *ToolKit) Tools(agent llm.Agent) []llm.Tool { + result := make([]llm.Tool, 0, len(kit.functions)) + for _, t := range kit.functions { + switch agent.Name() { + case "ollama": + t.InputSchema = nil + result = append(result, t) + default: + t.Parameters = nil + result = append(result, t) + } + } + return result +} + +// Register a tool in the toolkit +func (kit *ToolKit) Register(v llm.Tool) error { + if v == nil { + return llm.ErrBadParameter.With("tool cannot be nil") + } + + name := v.Name() + if _, exists := kit.functions[name]; exists { + return llm.ErrConflict.Withf("tool %q already exists", name) + } + + // Set the tool + t := tool{ + Tool: v, + ToolMeta: ToolMeta{ + Name: name, + Description: v.Description(), + }, + } + + // Determine parameters + toolparams, err := paramsFor(nil, v) + if err != nil { + return err + } + + // Add parameters + parameters := ToolParameters{ + Type: "object", + Required: make([]string, 0, len(toolparams)), + Properties: make(map[string]ToolParameter, len(toolparams)), + } + + // Set parameters + for _, param := range toolparams { + if _, exists := parameters.Properties[param.Name]; exists { + return llm.ErrConflict.Withf("parameter %q already exists", param.Name) + } else { + parameters.Properties[param.Name] = param + } + if param.required { + parameters.Required = append(parameters.Required, param.Name) + } + } + + t.Parameters = ¶meters + t.InputSchema = ¶meters + + // Add to toolkit + kit.functions[name] = t + + // Return success + return nil +} + +// Run calls a tool in the toolkit +func (kit *ToolKit) Run(ctx context.Context, calls ...llm.ToolCall) ([]llm.ToolResult, error) { + var wg sync.WaitGroup + var errs error + var toolresult []llm.ToolResult + + // TODO: Lock each tool so it can only be run in series (although different + // tools can be run in parallel) + for _, call := range calls { + wg.Add(1) + go func(call llm.ToolCall) { + defer wg.Done() + + // Get the tool and run it + name := call.Name() + if _, exists := kit.functions[name]; !exists { + errs = errors.Join(errs, llm.ErrNotFound.Withf("tool %q not found", name)) + } else if err := call.Decode(kit.functions[name].Tool); err != nil { + errs = errors.Join(errs, err) + } else if out, err := kit.functions[name].Tool.Run(ctx); err != nil { + errs = errors.Join(errs, err) + } else { + toolresult = append(toolresult, NewResult(call, out)) + } + }(call) + } + + // Wait for all calls to complete + wg.Wait() + + // Return any errors + return toolresult, errs +} diff --git a/toolkit.go b/toolkit.go new file mode 100644 index 0000000..e61d98d --- /dev/null +++ b/toolkit.go @@ -0,0 +1,54 @@ +package llm + +import ( + "context" +) + +//////////////////////////////////////////////////////////////////////////////// +// TYPES + +// ToolKit is a collection of tools +type ToolKit interface { + // Register a tool in the toolkit + Register(Tool) error + + // Return all the tools + Tools(Agent) []Tool + + // Run the tool calls in parallel and return the results + Run(context.Context, ...ToolCall) ([]ToolResult, error) +} + +// Definition of a tool +type Tool interface { + // The name of the tool + Name() string + + // The description of the tool + Description() string + + // Run the tool with a deadline and return the result + // TODO: Change 'any' to ToolResult + Run(context.Context) (any, error) +} + +// A call-out to a tool +type ToolCall interface { + // The tool name + Name() string + + // The tool identifier + Id() string + + // Decode the calling parameters + Decode(v any) error +} + +// Results from calling tools +type ToolResult interface { + // The call associated with the result + Call() ToolCall + + // The result, which can be encoded into json + Value() any +}