diff --git a/chat-history/Makefile b/chat-history/Makefile index 27100d98..09ae863d 100644 --- a/chat-history/Makefile +++ b/chat-history/Makefile @@ -14,6 +14,6 @@ clean: run: clean test build clear - CONFIG="config.json" DEV=true ./chat-history + CONFIG_FILES="chat_config.json,db_config.json" DEV=true ./chat-history diff --git a/chat-history/config/config.go b/chat-history/config/config.go index 3185635d..e067be00 100644 --- a/chat-history/config/config.go +++ b/chat-history/config/config.go @@ -2,7 +2,6 @@ package config import ( "encoding/json" - "fmt" "os" ) @@ -10,14 +9,20 @@ type LLMConfig struct { ModelName string `json:"model_name"` } -type DbConfig struct { - Port string `json:"apiPort"` - DbPath string `json:"dbPath"` - DbLogPath string `json:"dbLogPath"` - LogPath string `json:"logPath"` - // DbHostname string `json:"hostname"` - // Username string `json:"username"` - // Password string `json:"password"` +type ChatDbConfig struct { + Port string `json:"apiPort"` + DbPath string `json:"dbPath"` + DbLogPath string `json:"dbLogPath"` + LogPath string `json:"logPath"` + ConversationAccessRoles []string `json:"conversationAccessRoles"` +} + +type TgDbConfig struct { + Hostname string `json:"hostname"` + Username string `json:"username"` + Password string `json:"password"` + GsPort string `json:"gsPort"` + TgCloud bool `json:"tgCloud"` // GetToken string `json:"getToken"` // DefaultTimeout string `json:"default_timeout"` // DefaultMemThreshold string `json:"default_mem_threshold"` @@ -25,29 +30,55 @@ type DbConfig struct { } type Config struct { - DbConfig + ChatDbConfig + TgDbConfig // LLMConfig } -func LoadConfig(path string) (Config, error) { - var b []byte - if _, err := os.Stat(path); os.IsNotExist(err) { - // file doesn't exist read from env - cfg := os.Getenv("CONFIG") - if cfg == "" { - fmt.Println("CONFIG path is not found nor is the CONFIG json env variable defined") - os.Exit(1) +func LoadConfig(paths map[string]string) (Config, error) { + var config Config + + // Load database config + if dbConfigPath, ok := paths["chatdb"]; ok { + dbConfig, err := loadChatDbConfig(dbConfigPath) + if err != nil { + return Config{}, err } - b = []byte(cfg) - } else { - b, err = os.ReadFile(path) + config.ChatDbConfig = dbConfig + } + + // Load TigerGraph config + if tgConfigPath, ok := paths["tgdb"]; ok { + tgConfig, err := loadTgDbConfig(tgConfigPath) if err != nil { return Config{}, err } + config.TgDbConfig = tgConfig } - var cfg Config - json.Unmarshal(b, &cfg) + return config, nil +} - return cfg, nil +func loadChatDbConfig(path string) (ChatDbConfig, error) { + var dbConfig ChatDbConfig + b, err := os.ReadFile(path) + if err != nil { + return ChatDbConfig{}, err + } + if err := json.Unmarshal(b, &dbConfig); err != nil { + return ChatDbConfig{}, err + } + return dbConfig, nil +} + +func loadTgDbConfig(path string) (TgDbConfig, error) { + var tgConfig TgDbConfig + b, err := os.ReadFile(path) + if err != nil { + return TgDbConfig{}, err + } + if err := json.Unmarshal(b, &tgConfig); err != nil { + return TgDbConfig{}, err + } + return tgConfig, nil } diff --git a/chat-history/config/config_test.go b/chat-history/config/config_test.go index ea4c8c0c..9ddce8c8 100644 --- a/chat-history/config/config_test.go +++ b/chat-history/config/config_test.go @@ -7,41 +7,59 @@ import ( ) func TestLoadConfig(t *testing.T) { - pth := setup(t) - cfg, err := LoadConfig(pth) + chatConfigPath, tgConfigPath := setup(t) + + cfg, err := LoadConfig(map[string]string{ + "chatdb": chatConfigPath, + "tgdb": tgConfigPath, + }) if err != nil { t.Fatal(err) } - if cfg.Port != "8000" || - cfg.DbPath != "chats.db" || - cfg.DbLogPath != "db.log" || - cfg.LogPath != "requestLogs.jsonl" { - t.Fatalf("config is wrong, %v", cfg) + if cfg.ChatDbConfig.Port != "8002" || + cfg.ChatDbConfig.DbPath != "chats.db" || + cfg.ChatDbConfig.DbLogPath != "db.log" || + cfg.ChatDbConfig.LogPath != "requestLogs.jsonl" { + t.Fatalf("config is wrong, %v", cfg.ChatDbConfig) + } + + if cfg.TgDbConfig.Hostname != "https://tg-0cdef603-3760-41c3-af6f-41e95afc40de.us-east-1.i.tgcloud.io" || + cfg.TgDbConfig.GsPort != "14240" || + cfg.TgDbConfig.TgCloud != true { + t.Fatalf("TigerGraph config is wrong, %v", cfg.TgDbConfig) } } -func setup(t *testing.T) string { +func setup(t *testing.T) (string, string) { tmp := t.TempDir() - pth := fmt.Sprintf("%s/%s", tmp, "config.json") - dat := ` + chatConfigPath := fmt.Sprintf("%s/%s", tmp, "chat_config.json") + chatConfigData := ` { - "apiPort":"8000", - "hostname": "http://localhost:14240", - "dbPath": "chats.db", - "dbLogPath": "db.log", - "logPath": "requestLogs.jsonl", - "username": "tigergraph", - "password": "tigergraph", - "getToken": false, - "default_timeout": 300, - "default_mem_threshold": 5000, - "default_thread_limit": 8 + "apiPort":"8002", + "dbPath": "chats.db", + "dbLogPath": "db.log", + "logPath": "requestLogs.jsonl", + "conversationAccessRoles": ["superuser", "globaldesigner"] }` - err := os.WriteFile(pth, []byte(dat), 0644) - if err != nil { - t.Fatal("error setting up config.json") + + if err := os.WriteFile(chatConfigPath, []byte(chatConfigData), 0644); err != nil { + t.Fatal("error setting up chat_config.json") } - return pth + + tgConfigPath := fmt.Sprintf("%s/%s", tmp, "db_config.json") + tgConfigData := ` +{ + "hostname": "https://tg-0cdef603-3760-41c3-af6f-41e95afc40de.us-east-1.i.tgcloud.io", + "gsPort": "14240", + "username": "supportai", + "password": "supportai", + "tgCloud": true +}` + if err := os.WriteFile(tgConfigPath, []byte(tgConfigData), 0644); err != nil { + t.Fatal("error setting up tg_config.json") + } + + return chatConfigPath, tgConfigPath } diff --git a/chat-history/db/db.go b/chat-history/db/db.go index 9b7a0e10..71554076 100644 --- a/chat-history/db/db.go +++ b/chat-history/db/db.go @@ -131,15 +131,28 @@ func UpdateConversationById(message structs.Message) (*structs.Conversation, err return &convo, nil } +// GetAllMessages retrieves all messages from the database +func GetAllMessages() ([]structs.Message, error) { + var messages []structs.Message + + // Use GORM to query all messages + if err := db.Find(&messages).Error; err != nil { + return nil, err + } + + return messages, nil +} + func populateDB() { mu.Lock() defer mu.Unlock() // init convos conv1 := uuid.MustParse("601529eb-4927-4e24-b285-bd6b9519a951") + conv2 := uuid.MustParse("601529eb-4927-4e24-b285-bd6b9519a952") db.Create(&structs.Conversation{UserId: "sam_pull", ConversationId: conv1, Name: "conv1"}) - db.Create(&structs.Conversation{UserId: "sam_pull", ConversationId: uuid.New(), Name: "conv2"}) - db.Create(&structs.Conversation{UserId: "Miss_Take", ConversationId: uuid.New(), Name: "conv3"}) + db.Create(&structs.Conversation{UserId: "Miss_Take", ConversationId: conv2, Name: "conv2"}) + // db.Create(&structs.Conversation{UserId: "Miss_Take", ConversationId: uuid.New(), Name: "conv3"}) // add message to convos message := structs.Message{ @@ -152,8 +165,8 @@ func populateDB() { Feedback: structs.NoFeedback, Comment: "", } - db.Create(&message) + m2 := structs.Message{ ConversationId: conv1, MessageId: uuid.New(), @@ -165,4 +178,16 @@ func populateDB() { Comment: "", } db.Create(&m2) + + m3 := structs.Message{ + ConversationId: conv2, + MessageId: uuid.New(), + ParentId: &message.MessageId, + ModelName: "GPT-4o", + Content: "How many transactions?", + Role: structs.SystemRole, + Feedback: structs.NoFeedback, + Comment: "", + } + db.Create(&m3) } diff --git a/chat-history/db/db_test.go b/chat-history/db/db_test.go index ed50f17c..ac7a63c4 100644 --- a/chat-history/db/db_test.go +++ b/chat-history/db/db_test.go @@ -213,6 +213,29 @@ func TestParallelWrites(t *testing.T) { } } +func TestGetAllMessages(t *testing.T) { + setupTest(t, true) + + messages, err := GetAllMessages() + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + // Ensure that messages are returned + if len(messages) == 0 { + t.Fatalf("Expected some messages, got none") + } + + // Validate the structure of the messages + for _, m := range messages { + if uuid.Validate(m.ConversationId.String()) != nil || + uuid.Validate(m.MessageId.String()) != nil || + (m.Role != "system" && m.Role != "user") { + t.Fatalf("Invaid message structure: %v", m) + } + } +} + /* helper functions */ diff --git a/chat-history/go.mod b/chat-history/go.mod index f132c6bc..cc9328ab 100644 --- a/chat-history/go.mod +++ b/chat-history/go.mod @@ -3,14 +3,14 @@ module chat-history go 1.22.3 require ( - github.com/go-chi/chi/v5 v5.0.12 + github.com/go-chi/httplog/v2 v2.0.11 github.com/google/uuid v1.6.0 gorm.io/driver/sqlite v1.5.5 gorm.io/gorm v1.25.10 ) require ( - github.com/go-chi/httplog/v2 v2.0.11 // indirect + github.com/go-chi/chi/v5 v5.0.12 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/mattn/go-sqlite3 v1.14.22 // indirect diff --git a/chat-history/go.sum b/chat-history/go.sum index cce54bb7..4c7b12d8 100644 --- a/chat-history/go.sum +++ b/chat-history/go.sum @@ -1,3 +1,5 @@ +github.com/GenericP3rson/TigerGo v0.0.4 h1:xI7d/cLJ6sRP4fzanInakARE0XGk1YAmvn5KrH1fwFU= +github.com/GenericP3rson/TigerGo v0.0.4/go.mod h1:PGpAFO9vNA7l34WSGYCtWb/eqVKHuIq1xqvizBlNhRM= github.com/go-chi/chi/v5 v5.0.12 h1:9euLV5sTrTNTRUU9POmDUvfxyj6LAABLUcEWO+JJb4s= github.com/go-chi/chi/v5 v5.0.12/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= github.com/go-chi/httplog/v2 v2.0.11 h1:eu6kYksMEJzBcOP+ba/iYudc0m5rv4VvBAzroJMkaY4= diff --git a/chat-history/main.go b/chat-history/main.go index e913bcef..1f82effd 100644 --- a/chat-history/main.go +++ b/chat-history/main.go @@ -12,12 +12,18 @@ import ( ) func main() { - configPath:= os.Getenv("CONFIG") - config, err := config.LoadConfig(configPath) + configPath := os.Getenv("CONFIG_FILES") + // Split the paths into a slice + configPaths := strings.Split(configPath, ",") + + cfg, err := config.LoadConfig(map[string]string{ + "chatdb": configPaths[0], + "tgdb": configPaths[1], + }) if err != nil { panic(err) } - db.InitDB(config.DbPath, config.DbLogPath) + db.InitDB(cfg.ChatDbConfig.DbPath, cfg.ChatDbConfig.DbLogPath) // make router router := http.NewServeMux() @@ -30,14 +36,15 @@ func main() { router.HandleFunc("GET /user/{userId}", routes.GetUserConversations) router.HandleFunc("GET /conversation/{conversationId}", routes.GetConversation) router.HandleFunc("POST /conversation", routes.UpdateConversation) + router.HandleFunc("GET /get_feedback", routes.GetFeedback(cfg.TgDbConfig.Hostname, cfg.TgDbConfig.GsPort, cfg.ChatDbConfig.ConversationAccessRoles, cfg.TgDbConfig.TgCloud)) // create server with middleware dev := strings.ToLower(os.Getenv("DEV")) == "true" var port string if dev { - port = fmt.Sprintf("localhost:%s", config.Port) + port = fmt.Sprintf("localhost:%s", cfg.ChatDbConfig.Port) } else { - port = fmt.Sprintf(":%s", config.Port) + port = fmt.Sprintf(":%s", cfg.ChatDbConfig.Port) } handler := middleware.ChainMiddleware(router, diff --git a/chat-history/routes/routes.go b/chat-history/routes/routes.go index 076cb41a..0524f7a4 100644 --- a/chat-history/routes/routes.go +++ b/chat-history/routes/routes.go @@ -3,10 +3,12 @@ package routes import ( "chat-history/db" "chat-history/structs" + "encoding/base64" "encoding/json" "fmt" "io" "net/http" + "net/url" "slices" "strings" ) @@ -154,3 +156,152 @@ func auth(userId string, r *http.Request) (string, int, []byte, bool) { return usr, 0, nil, true } + +// executeGSQL sends a GSQL query to TigerGraph with basic authentication and returns the response +func executeGSQL(hostname, username, password, query, gsPort string, tgcloud bool) (string, error) { + var requestURL string + // Construct the URL for the GSQL query endpoint + if tgcloud { + requestURL = fmt.Sprintf("%s:443/gsqlserver/gsql/file", hostname) + } else { + requestURL = fmt.Sprintf("%s:%s/gsqlserver/gsql/file", hostname, gsPort) + } + // Prepare the query data + data := url.QueryEscape(query) // Encode query using URL encoding + reqBody := strings.NewReader(data) + + // Create the HTTP request + req, err := http.NewRequest("POST", requestURL, reqBody) + if err != nil { + return "", err + } + + // Set the required headers + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + // Set up basic authentication + auth := fmt.Sprintf("%s:%s", username, password) + req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(auth))) + + // Execute the request + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + // Read and return the response body + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + + return string(body), nil +} + +// hasAdminAccess checks if the user's roles include any of the admin roles +func hasAdminAccess(userRoles []string, adminRoles []string) bool { + for _, role := range userRoles { + for _, adminRole := range adminRoles { + if role == adminRole { + return true + } + } + } + return false +} + +// parseUserRoles extracts roles from the user information string +func parseUserRoles(userInfo string, userName string) []string { + lines := strings.Split(userInfo, "\n") + var roles []string + var isUserSection bool + + for _, line := range lines { + if strings.Contains(line, "Name:") { + isUserSection = strings.Contains(line, userName) + } + if isUserSection && strings.Contains(line, "- Global Roles:") { + parts := strings.Split(line, ":") + if len(parts) > 1 { + roles = append(roles, strings.Split(strings.TrimSpace(parts[1]), ", ")...) + } + } + } + + return roles +} + +// GetFeedback retrieves feedback data for conversations +// "Get /get_feedback" +func GetFeedback(hostname, gsPort string, conversationAccessRoles []string, tgCloud bool) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + usr, pass, ok := r.BasicAuth() + if !ok { + w.Header().Add("Content-Type", "application/json") + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte(`{"reason":"missing Authorization header"}`)) + return + } + + // Verify if the user has the required role + userInfo, err := executeGSQL(hostname, usr, pass, "SHOW USER", gsPort, tgCloud) + if err != nil { + reason := []byte(`{"reason":"failed to retrieve feedback data"}`) + w.Header().Add("Content-Type", "application/json") + w.WriteHeader(http.StatusInternalServerError) + w.Write(reason) + return + } + + // Parse and check roles + userRoles := parseUserRoles(userInfo, usr) + if !hasAdminAccess(userRoles, conversationAccessRoles) { + // Fetch chat history messages for this specific user + conversations := db.GetUserConversations(usr) + + var allMessages []structs.Message + + for _, convo := range conversations { + messages := db.GetUserConversationById(usr, convo.ConversationId.String()) + allMessages = append(allMessages, messages...) + } + // Marshal and write the response + w.Header().Add("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + response, err := json.Marshal(allMessages) + if err != nil { + reason := []byte(`{"reason":"failed to marshal messages"}`) + w.Header().Add("Content-Type", "application/json") + w.WriteHeader(http.StatusInternalServerError) + w.Write(reason) + return + } + w.Write(response) + return + } + + // If the user has admin access, fetch all messages + messages, err := db.GetAllMessages() + if err != nil { + reason := []byte(`{"reason":"failed to retrieve feedback data"}`) + w.Header().Add("Content-Type", "application/json") + w.WriteHeader(http.StatusInternalServerError) + w.Write(reason) + return + } + + w.Header().Add("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + response, err := json.Marshal(messages) + if err != nil { + reason := []byte(`{"reason":"failed to marshal messages"}`) + w.Header().Add("Content-Type", "application/json") + w.WriteHeader(http.StatusInternalServerError) + w.Write(reason) + return + } + w.Write(response) + } +} diff --git a/chat-history/routes/routes_test.go b/chat-history/routes/routes_test.go index 1411e7b0..075bfe49 100644 --- a/chat-history/routes/routes_test.go +++ b/chat-history/routes/routes_test.go @@ -2,6 +2,7 @@ package routes import ( "bytes" + "chat-history/config" "chat-history/db" "chat-history/structs" "encoding/base64" @@ -12,6 +13,7 @@ import ( "net/http/httptest" "os" "slices" + "strings" "testing" "github.com/google/uuid" @@ -19,7 +21,7 @@ import ( const ( USER = "sam_pull" - PASS = "pass" + PASS = "sam_pull" CONVO_ID = "601529eb-4927-4e24-b285-bd6b9519a951" ) @@ -395,6 +397,170 @@ func messageEquals(m, msg structs.Message) bool { return false } +func TestExecuteGSQL(t *testing.T) { + + os.Setenv("CONFIG_FILES", "../chat_config.json,../db_config.json") + + configPath := os.Getenv("CONFIG_FILES") + // Split the paths into a slice + configPaths := strings.Split(configPath, ",") + + cfg, err := config.LoadConfig(map[string]string{ + "chatdb": configPaths[0], + "tgdb": configPaths[1], + }) + if err != nil { + panic(err) + } + query := "SHOW USER" + + response, err := executeGSQL(cfg.TgDbConfig.Hostname, cfg.TgDbConfig.Username, cfg.TgDbConfig.Password, query, cfg.TgDbConfig.GsPort, cfg.TgDbConfig.TgCloud) + if err != nil { + t.Fatalf("Failed to execute GSQL query: %v", err) + } + + // Check for common errors or issues in the response + if strings.Contains(response, "400 Bad Request") { + t.Error("Received '400 Bad Request' error. Please check the query and server configuration.") + } + + if strings.Contains(response, "401 Unauthorized") { + t.Error("Received '401 Unauthorized' error. Please check the credentials and access permissions.") + } + + if strings.Contains(response, "403 Forbidden") { + t.Error("Received '403 Forbidden' error. The user may not have sufficient permissions to execute the query.") + } + + if strings.Contains(response, "500 Internal Server Error") { + t.Error("Received '500 Internal Server Error'. This indicates a server-side issue.") + } + + // Add any additional checks on the response + if response == "" { + t.Error("Received empty response from GSQL query") + } + + // Check if the response contains "Name" and "Global Roles" + if !strings.Contains(response, "Name") { + t.Error("Response does not contain 'Name'.") + } + + if !strings.Contains(response, "Global Roles") { + t.Error("Response does not contain 'Global Roles'.") + } +} + +func TestParseUserRoles(t *testing.T) { + userInfo := ` + - Name: feedbackauthtest + - Global Roles: globalobserver + - Graph 'EarningsCallRAG' Roles: queryreader + - Graph 'Transaction_Fraud' Roles: designer, queryreader + - Graph 'pyTigerGraphRAG' Roles: queryreader, querywriter + - Secret: ad9****v7p + - Alias: AUTO_GENERATED_ALIAS_suv6mm5 + - GraphName: Transaction_Fraud + - LastSuccessLogin: Mon Jul 22 06:57:29 UTC 2024 + - NextValidLogin: Mon Jul 22 06:57:29 UTC 2024 + - FailedAttempts: 0 + - ShowAlterPasswordWarning: false + + - Name: Lu Zhou + - Global Roles: globalobserver + - LastSuccessLogin: Tue Jul 23 16:35:45 UTC 2024 + - NextValidLogin: Tue Jul 23 16:35:45 UTC 2024 + - FailedAttempts: 0 + - ShowAlterPasswordWarning: false + ` + + expectedRoles := []string{"globalobserver"} + + roles := parseUserRoles(userInfo, "feedbackauthtest") + + fmt.Println("Extracted Roles:", roles) + + if len(roles) != len(expectedRoles) { + t.Fatalf("expected %d roles, got %d", len(expectedRoles), len(roles)) + } + + for i, role := range expectedRoles { + if roles[i] != role { + t.Errorf("expected role %s, got %s", role, roles[i]) + } + } +} + +func TestGetFeedback(t *testing.T) { + + os.Setenv("CONFIG_FILES", "../chat_config.json,../db_config.json") + + setupDB(t, true) + + configPath := os.Getenv("CONFIG_FILES") + // Split the paths into a slice + configPaths := strings.Split(configPath, ",") + + cfg, err := config.LoadConfig(map[string]string{ + "chatdb": configPaths[0], + "tgdb": configPaths[1], + }) + if err != nil { + panic(err) + } + testFeedback := func(t *testing.T, username, password string, expectedStatus int, expectedMessagesCount int, expectedFirstMessageContent string) { + // Create a request with Basic Auth + req, err := http.NewRequest("GET", "/get_feedback", nil) + if err != nil { + t.Fatal(err) + } + req.SetBasicAuth(username, password) + + // Record the response + rr := httptest.NewRecorder() + handler := http.HandlerFunc(GetFeedback(cfg.TgDbConfig.Hostname, cfg.TgDbConfig.GsPort, cfg.ChatDbConfig.ConversationAccessRoles, cfg.TgDbConfig.TgCloud)) + + // Serve the request + handler.ServeHTTP(rr, req) + + // Check the response status code + if status := rr.Code; status != http.StatusOK { + t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusOK) + } + if expectedStatus == http.StatusOK { + // Check the response body for expected messages + var messages []structs.Message + if err := json.Unmarshal(rr.Body.Bytes(), &messages); err != nil { + t.Errorf("Failed to parse response body: %v", err) + } + + // Print the messages for debugging + // fmt.Println("Retrieved messages:", messages) + // Validate that the messages are as expected + // expectedMessagesCount := 2 // Based on populateDB function + if len(messages) != expectedMessagesCount { + t.Errorf("Expected %d messages, got %d", expectedMessagesCount, len(messages)) + } + + // Additional checks to ensure the response contains the correct data + if expectedMessagesCount > 0 && len(messages) > 0 { + if messages[0].Content != expectedFirstMessageContent { + t.Errorf("Unexpected message content: %v", messages[0].Content) + } + } + } + } + + // Test case for admin user + testFeedback(t, "supportai", "supportai", http.StatusOK, 3, "This is the first message, there is no parent") + + // Test case for non-admin user + testFeedback(t, "sam_pull", "sam_pull", http.StatusOK, 2, "This is the first message, there is no parent") + + // Test case for non-existent user + testFeedback(t, "nonexistentuser", "password", http.StatusUnauthorized, 0, "") +} + // helpers func basicAuthSetup(user, pass string) string { return base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", user, pass))) diff --git a/chat-history/structs/structs.go b/chat-history/structs/structs.go index 8a34249e..9f5a3300 100644 --- a/chat-history/structs/structs.go +++ b/chat-history/structs/structs.go @@ -52,7 +52,7 @@ type Message struct { Content string `json:"content"` Role MessagengerRole `json:"role"` ResponseTime float64 `json:"response_time"` - Feedback Feedback `json:"feedback"`// time in fractional seconds (i.e., 1.25 seconds) + Feedback Feedback `json:"feedback"` // time in fractional seconds (i.e., 1.25 seconds) Comment string `json:"comment"` } diff --git a/copilot/app/routers/ui.py b/copilot/app/routers/ui.py index 12f4db42..9707976f 100644 --- a/copilot/app/routers/ui.py +++ b/copilot/app/routers/ui.py @@ -168,6 +168,31 @@ async def get_conversation_contents( return res.json() +@router.get(route_prefix + "/get_feedback") +async def get_conversation_feedback( + creds: Annotated[tuple[list[str], HTTPBasicCredentials], Depends(ui_basic_auth)], +): + creds = creds[1] + auth = base64.b64encode(f"{creds.username}:{creds.password}".encode()).decode() + try: + async with httpx.AsyncClient() as client: + res = await client.get( + f"{db_config['chat_history_api']}/get_feedback", + headers={"Authorization": f"Basic {auth}"}, + ) + res.raise_for_status() + except httpx.HTTPStatusError as e: + logger.error(f"HTTP error occurred: {e}") + raise HTTPException(status_code=e.response.status_code, detail="Failed to fetch feedback") + except Exception as e: + exc = traceback.format_exc() + logger.debug_pii( + f"/get_feedback request_id={req_id_cv.get()} Exception Trace:\n{exc}" + ) + raise HTTPException(status_code=500, detail="Internal server error") + + return res.json() + async def emit_progress(agent: TigerGraphAgent, ws: WebSocket): # loop on q until done token emit events through ws diff --git a/copilot/docs/notebooks/FeedbackAnalysis.ipynb b/copilot/docs/notebooks/FeedbackAnalysis.ipynb index 447882c8..8967fab9 100644 --- a/copilot/docs/notebooks/FeedbackAnalysis.ipynb +++ b/copilot/docs/notebooks/FeedbackAnalysis.ipynb @@ -9,54 +9,32 @@ "import requests\n", "import base64\n", "\n", - "def create_headers(username, password):\n", + "def fetch_conversation_feedback(username, password): \n", " \"\"\"Create headers with Base64 encoded credentials.\"\"\"\n", " credentials = f\"{username}:{password}\"\n", " encoded_credentials = base64.b64encode(credentials.encode(\"utf-8\")).decode(\"utf-8\")\n", - " return {\n", + " headers = {\n", " 'accept': 'application/json',\n", " 'Authorization': f'Basic {encoded_credentials}'\n", - " }" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "def get_user_conversation_ids(username, password):\n", - " \"\"\"Fetch conversation IDs for a given user.\"\"\"\n", - " headers = create_headers(username, password)\n", - " user_url = f'http://COPILOT_ADDRESS/ui/user/{username}'\n", - " \n", - " response = requests.get(user_url, headers=headers)\n", - " \n", - " if response.status_code == 200:\n", - " data = response.json()\n", - " return [item['conversation_id'] for item in data]\n", - " else:\n", - " print(f\"Request failed with status code {response.status_code}\")\n", - " print(response.text)\n", - " return None" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "def get_conversation_data(username, password, conversation_id):\n", - " \"\"\"Fetch conversation data for a given conversation ID.\"\"\"\n", - " headers = create_headers(username, password)\n", - " conversation_url = f'http://COPILOT_ADDRESS/ui/conversation/{conversation_id}'\n", - " \n", - " response = requests.get(conversation_url, headers=headers)\n", - " \n", - " if response.status_code == 200:\n", - " data = response.json()\n", + " }\n", "\n", + " \"\"\"Fetch conversation and feedback data.\"\"\"\n", + " feedback_url = f'http://COPILOT_HOST/ui/get_feedback'\n", + " \n", + " try:\n", + " response = requests.get(feedback_url, headers=headers)\n", + " response.raise_for_status()\n", + "\n", + " try:\n", + " data = response.json()\n", + " except ValueError:\n", + " print(\"Error decoding JSON from response.\")\n", + " return None\n", + " \n", + " if not data:\n", + " print(\"No data received.\")\n", + " return None\n", + " \n", " # Create dictionaries to hold user questions and system answers\n", " questions = {message[\"message_id\"]: message for message in data if message[\"role\"] == \"user\"}\n", " answers = {message[\"parent_id\"]: message for message in data if message[\"role\"] == \"system\"}\n", @@ -73,50 +51,18 @@ " })\n", " \n", " return qa_pairs\n", - " else:\n", - " print(f\"Request failed with status code {response.status_code}\")\n", - " print(response.text)\n", + " except requests.RequestException as e:\n", + " print(f\"Request failed: {e}\")\n", " return None" ] }, { "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "def fetch_user_conversations(username, password, conversation_id=None):\n", - " if conversation_id:\n", - " conversations = {}\n", - " # Fetch a specific conversation\n", - " data = get_conversation_data(username, password, conversation_id)\n", - " \n", - " if data:\n", - " conversations[conversation_id] = data\n", - " return conversations\n", - " else:\n", - " return \"Conversation not found or could not be retrieved.\"\n", - " \n", - " else:\n", - " # Fetch all conversations\n", - " conversation_ids = get_user_conversation_ids(username, password)\n", - " conversations = {}\n", - " \n", - " for conv_id in conversation_ids:\n", - " data = get_conversation_data(username, password, conv_id)\n", - " if data:\n", - " conversations[conv_id] = data\n", - " \n", - " return conversations" - ] - }, - { - "cell_type": "code", - "execution_count": 5, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ - "conversation_data = fetch_user_conversations(\"YOUR_DB_USERNAME\", \"YOUR_DB_PASSWORD\")" + "conversation_data = fetch_conversation_feedback(\"DB_USERNAME\", \"DB_PASSWORD\")" ] }, { diff --git a/copilot/requirements.txt b/copilot/requirements.txt index d45f2a60..7a8bd83f 100644 --- a/copilot/requirements.txt +++ b/copilot/requirements.txt @@ -152,4 +152,4 @@ wandb==0.15.12 watchfiles==0.20.0 websockets==11.0.3 yarl==1.9.2 -zipp==3.19.2 +zipp==3.19.2 \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml index 6aa8c01f..2d03dcbe 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -45,7 +45,7 @@ services: ports: - 8002:8002 environment: - CONFIG: "/configs/chat_config.json" + CONFIG_FILES: "/configs/chat_config.json,/configs/db_config.json" LOGLEVEL: "INFO" volumes: - ./configs/:/configs