Skip to content

Commit

Permalink
WIP: openai support custom assistant/GPT
Browse files Browse the repository at this point in the history
  • Loading branch information
brainexe committed Jan 12, 2024
1 parent 24af0e9 commit 0aef7bc
Show file tree
Hide file tree
Showing 6 changed files with 329 additions and 13 deletions.
6 changes: 4 additions & 2 deletions command/openai/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
const (
apiHost = "https://api.openai.com"
apiCompletionURL = "/v1/chat/completions"
apiThreadsURL = "/v1/threads"
apiDalleGenerateImageURL = "/v1/images/generations"
)

Expand All @@ -25,14 +26,15 @@ var client = http.Client{
Timeout: 60 * time.Second,
}

func doRequest(cfg Config, apiEndpoint string, data []byte) (*http.Response, error) {
req, err := http.NewRequest("POST", cfg.APIHost+apiEndpoint, bytes.NewBuffer(data))
func doRequest(cfg Config, method string, apiEndpoint string, data []byte) (*http.Response, error) {
req, err := http.NewRequest(method, cfg.APIHost+apiEndpoint, bytes.NewBuffer(data))
if err != nil {
return nil, err
}

req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+cfg.APIKey)
req.Header.Set("OpenAI-Beta", "assistants=v1")

return client.Do(req)
}
Expand Down
300 changes: 300 additions & 0 deletions command/openai/assistant.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,300 @@
package openai

import (
"encoding/json"
"fmt"

Check failure on line 5 in command/openai/assistant.go

View workflow job for this annotation

GitHub Actions / Lint

File is not `gofumpt`-ed (gofumpt)
"github.com/innogames/slack-bot/v2/bot/msg"
"github.com/innogames/slack-bot/v2/bot/storage"
"github.com/slack-go/slack"
"io"

Check failure on line 9 in command/openai/assistant.go

View workflow job for this annotation

GitHub Actions / Lint

File is not `gofumpt`-ed (gofumpt)
"time"
)

// see https://platform.openai.com/docs/assistants/how-it-works

type assistantThreadResponse struct {
Id string `json:"id"`

Check warning on line 16 in command/openai/assistant.go

View workflow job for this annotation

GitHub Actions / Lint

var-naming: struct field Id should be ID (revive)
}
type assistantStartRun struct {
AssistantId string `json:"assistant_id"`

Check warning on line 19 in command/openai/assistant.go

View workflow job for this annotation

GitHub Actions / Lint

var-naming: struct field AssistantId should be AssistantID (revive)
}

type run struct {
Id string `json:"id"`

Check warning on line 23 in command/openai/assistant.go

View workflow job for this annotation

GitHub Actions / Lint

var-naming: struct field Id should be ID (revive)
Status string `json:"status"`
ThreadId string `json:"thread_id"`

Check warning on line 25 in command/openai/assistant.go

View workflow job for this annotation

GitHub Actions / Lint

var-naming: struct field ThreadId should be ThreadID (revive)
RequiredAction AssistantRequiredAction `json:"required_action"`
}

type AssistantRequiredAction struct {
Type string `json:"type"`
SubmitToolsOutputs struct {
ToolCalls []struct {
Id string `json:"id"`

Check warning on line 33 in command/openai/assistant.go

View workflow job for this annotation

GitHub Actions / Lint

var-naming: struct field Id should be ID (revive)
Type string `json:"type"`
Function struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
} `json:"function"`
} `json:"tool_calls"`
} `json:"submit_tool_outputs"`
}

type AssistantContent struct {
Type string `json:"type"`
Text struct {
Value string `json:"value"`
} `json:"text"`
}

func (c AssistantContent) GetText() string {
return c.Text.Value
}

type AssistantStartThreads struct {
Messages []ChatMessage `json:"messages"`
}

type AssistantChatMessage struct {
Id string `json:"id"`
Role string `json:"role"`
ChatMessage []AssistantContent `json:"content"`
RunId string `json:"run_id"`

Check warning on line 62 in command/openai/assistant.go

View workflow job for this annotation

GitHub Actions / Lint

var-naming: struct field RunId should be RunID (revive)
}
type assistantFullResponse struct {
Data []AssistantChatMessage `json:"data"`
}

type AssistantToolsOutput struct {
ToolsOutput []struct {
ToolCallId string `json:"tool_call_id"`

Check warning on line 70 in command/openai/assistant.go

View workflow job for this annotation

GitHub Actions / Lint

var-naming: struct field ToolCallId should be ToolCallID (revive)
Output string `json:"output"`
} `json:"tool_outputs"`
}

func (c *openaiCommand) callCustomGPT(messages []ChatMessage, identifier string, message msg.Ref, text string) {
c.AddReaction(":coffee:", message)
defer c.RemoveReaction(":coffee:", message)

messages = append(messages, ChatMessage{
Role: roleUser,
Content: text,
})

var threadId string

Check warning on line 84 in command/openai/assistant.go

View workflow job for this annotation

GitHub Actions / Lint

var-naming: var threadId should be threadID (revive)
var err error
storage.Read("gpt-thread", identifier, &threadId)

Check failure on line 86 in command/openai/assistant.go

View workflow job for this annotation

GitHub Actions / Lint

Error return value of `storage.Read` is not checked (errcheck)
if threadId == "" {
// start a new thread!
threadId, err = createAssistantThread(c.cfg, messages)
if err != nil {
c.ReplyError(message, err)
return
}
storage.Write("gpt-thread", identifier, threadId)

Check failure on line 94 in command/openai/assistant.go

View workflow job for this annotation

GitHub Actions / Lint

Error return value of `storage.Write` is not checked (errcheck)
} else {
// attach slack messages to an existing thread
for _, newMessage := range messages {
// todo no API to bulk add?!
addMessage(c.cfg, threadId, newMessage)

Check failure on line 99 in command/openai/assistant.go

View workflow job for this annotation

GitHub Actions / Lint

Error return value is not checked (errcheck)
}
}

// start the assistant and get a "run" object
run, err := assistantRun(c.cfg, threadId)
if err != nil {
c.ReplyError(message, err)
return
}

// wait till run is done or required more information from function calls!
// see https://platform.openai.com/docs/assistants/how-it-works/run-lifecycle
ticker := time.NewTicker(time.Second * 1)
defer ticker.Stop()
for range ticker.C {
run, err = getRun(c.cfg, run)
if err != nil || run.Status == "failed" || run.Status == "cancelled" || run.Status == "expired" {
c.ReplyError(message, fmt.Errorf("run failed with status %s", run.Status))
return
}

if run.Status == "completed" {
// we have the final answer!
break
}

if run.Status == "requires_action" {
// todo extract code!
fmt.Println(run.RequiredAction)
fmt.Println(run.RequiredAction.SubmitToolsOutputs)
tool := run.RequiredAction.SubmitToolsOutputs.ToolCalls[0]

var output string
if tool.Function.Name == "dall_image" {
// special function
prompt := tool.Function.Arguments
fmt.Println(prompt, "prompt")

images, _ := generateImages(c.cfg, prompt)
output = images[0].RevisedPrompt
go c.sendImageInSlack(images[0], message)

Check failure on line 140 in command/openai/assistant.go

View workflow job for this annotation

GitHub Actions / Lint

Error return value of `c.sendImageInSlack` is not checked (errcheck)
} else {
output = "Ticket: Fix issue in feature XYZ, status = open" // todo call function
}

sendToolsOutput(c.cfg, run, tool.Id, output)

Check failure on line 145 in command/openai/assistant.go

View workflow job for this annotation

GitHub Actions / Lint

Error return value is not checked (errcheck)

// wait for new tick, as the API is handling the new information now...
continue
}
}

// todo only fetch the new messages for this run
respMessages, _ := listMessages(c.cfg, threadId)
for _, m := range respMessages {
if m.RunId != run.Id {
continue
}
fmt.Println(m.ChatMessage)
if m.Role != roleAssistant {
continue
}

// reply in thread
c.SendMessage(
message,
m.ChatMessage[0].GetText(),
slack.MsgOptionTS(message.GetTimestamp()),
)
}
}

/*
func (c *openaiCommand) assistantUploadFile(cfg Config, file slack.File) error {
var buf bytes.Buffer
log.Infof("Downloading message attachment file %s", file.Name)
fmt.Println(file)
resp, err := doRequest(cfg, "POST", apiFilesURL, []byte("jolo"))
if err != nil {
return nil
}
r, _ := io.ReadAll(resp.Body)
fmt.Println(string(r))
return nil
}
*/

func assistantRun(cfg Config, threadId string) (*run, error) {

Check warning on line 191 in command/openai/assistant.go

View workflow job for this annotation

GitHub Actions / Lint

var-naming: func parameter threadId should be threadID (revive)
fmt.Printf("run assistant %s\n", threadId)

assistantStartRun := assistantStartRun{
AssistantId: cfg.CustomGPT,
}

req, _ := json.Marshal(assistantStartRun)
resp, err := doRequest(cfg, "POST", apiThreadsURL+"/"+threadId+"/runs", req)
if err != nil {
return nil, err
}

run := &run{}
err = json.NewDecoder(resp.Body).Decode(run)
return run, err
}

func addMessage(cfg Config, threadId string, message ChatMessage) error {
fmt.Printf("add message to thread %s: %s\n", threadId, message)

req, _ := json.Marshal(message)
_, err := doRequest(cfg, "POST", apiThreadsURL+"/"+threadId+"/messages", req)

return err
}

func createAssistantThread(cfg Config, messages []ChatMessage) (string, error) {
fmt.Println("create thread")

req, _ := json.Marshal(AssistantStartThreads{
Messages: messages,
})
fmt.Println(string(req))
resp, err := doRequest(cfg, "POST", apiThreadsURL, req)
if err != nil {
return "", err
}
//r, _ := io.ReadAll(resp.Body)

Check failure on line 229 in command/openai/assistant.go

View workflow job for this annotation

GitHub Actions / Lint

commentFormatting: put a space between `//` and comment text (gocritic)
//fmt.Println(string(r))
thread := assistantThreadResponse{}
err = json.NewDecoder(resp.Body).Decode(&thread)
if err != nil {
return "", err
}
fmt.Println(thread)

if thread.Id == "" {
return "", fmt.Errorf("failed to create thread")
}
return thread.Id, nil
}

func getRun(cfg Config, oldRun *run) (*run, error) {
fmt.Printf("get run %s %s\n", oldRun.ThreadId, oldRun.Id)
resp, err := doRequest(cfg, "GET", apiThreadsURL+"/"+oldRun.ThreadId+"/runs/"+oldRun.Id, nil)
if err != nil {
return oldRun, err
}

r, _ := io.ReadAll(resp.Body)
fmt.Println(string(r))

newRun := &run{}
err = json.Unmarshal(r, newRun)

return newRun, err
}

func listMessages(cfg Config, threadId string) ([]AssistantChatMessage, error) {
fmt.Printf("list messages %s \n", threadId)
resp, err := doRequest(cfg, "GET", apiThreadsURL+"/"+threadId+"/messages", nil)
if err != nil {
return []AssistantChatMessage{}, err
}

r, _ := io.ReadAll(resp.Body)
fmt.Println(string(r))

messages := assistantFullResponse{}
json.Unmarshal(r, &messages)

Check failure on line 271 in command/openai/assistant.go

View workflow job for this annotation

GitHub Actions / Lint

Error return value of `json.Unmarshal` is not checked (errcheck)

return messages.Data, nil
}

func sendToolsOutput(cfg Config, run *run, callId string, output string) error {
fmt.Printf("send tools output %s %s %s\n", run.ThreadId, run.Id, callId)

req, _ := json.Marshal(AssistantToolsOutput{
ToolsOutput: []struct {
ToolCallId string `json:"tool_call_id"`
Output string `json:"output"`
}{
{
ToolCallId: callId,
Output: output,
},
},
})
fmt.Println(string(req))
resp, err := doRequest(cfg, "POST", apiThreadsURL+"/"+run.ThreadId+"/runs/"+run.Id+"/submit_tool_outputs", req)
if err != nil {
return err
}

r, _ := io.ReadAll(resp.Body)
fmt.Println(string(r))

return err
}
2 changes: 1 addition & 1 deletion command/openai/chatgpt.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func CallChatGPT(cfg Config, inputMessages []ChatMessage, stream bool) (<-chan s
Stream: stream,
Messages: inputMessages,
})
resp, err := doRequest(cfg, apiCompletionURL, jsonData)
resp, err := doRequest(cfg, "POST", apiCompletionURL, jsonData)
if err != nil {
messageUpdates <- err.Error()
return
Expand Down
18 changes: 15 additions & 3 deletions command/openai/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,13 @@ func (c *openaiCommand) GetMatcher() matcher.Matcher {
// bot function which is called, when the user started a new conversation with openai/chatgpt
func (c *openaiCommand) newConversation(match matcher.Result, message msg.Message) {
text := match.GetString(util.FullMatch)
c.startConversation(message.MessageRef, text)
c.startConversation(message, text)
}

func (c *openaiCommand) startConversation(message msg.Ref, text string) bool {
messageHistory := make([]ChatMessage, 0)

if c.cfg.InitialSystemMessage != "" {
if c.cfg.InitialSystemMessage != "" && c.cfg.CustomGPT == "" {
messageHistory = append(messageHistory, ChatMessage{
Role: roleSystem,
Content: c.cfg.InitialSystemMessage,
Expand Down Expand Up @@ -135,7 +135,12 @@ func (c *openaiCommand) startConversation(message msg.Ref, text string) bool {
storageIdentifier = getIdentifier(message.GetChannel(), message.GetTimestamp())
}

c.callAndStore(messageHistory, storageIdentifier, message, text)
if c.cfg.CustomGPT != "" {
c.callCustomGPT(messageHistory, storageIdentifier, message, text)
} else {
// usual GPT-X model
c.callAndStore(messageHistory, storageIdentifier, message, text)
}
return true
}

Expand All @@ -149,6 +154,13 @@ func (c *openaiCommand) reply(message msg.Ref, text string) bool {
// Load the chat history from storage.
identifier := getIdentifier(message.GetChannel(), message.GetThread())

var threadId string

Check warning on line 157 in command/openai/command.go

View workflow job for this annotation

GitHub Actions / Lint

var-naming: var threadId should be threadID (revive)
storage.Read("gpt-thread", identifier, &threadId)

Check failure on line 158 in command/openai/command.go

View workflow job for this annotation

GitHub Actions / Lint

Error return value of `storage.Read` is not checked (errcheck)
if threadId != "" {
c.callCustomGPT([]ChatMessage{}, identifier, message, text)
return true
}

var messages []ChatMessage
err := storage.Read(storageKey, identifier, &messages)
if err != nil || len(messages) == 0 {
Expand Down
2 changes: 2 additions & 0 deletions command/openai/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ type Config struct {
DalleModel string `mapstructure:"dalle_model"`
DalleImageSize string `mapstructure:"dalle_image_size"`
DalleNumberOfImages int `mapstructure:"dalle_number_of_images"`

CustomGPT string `mapstructure:"custom_gpt"`
}

// IsEnabled checks if token is set
Expand Down
Loading

0 comments on commit 0aef7bc

Please sign in to comment.