Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gml 1809 feedback analysis access control #256

Merged
merged 15 commits into from
Aug 2, 2024
2 changes: 1 addition & 1 deletion chat-history/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@ clean:

run: clean test build
clear
CONFIG="config.json" DEV=true ./chat-history
CONFIG_FILES="chat_config.json,db_config.json" DEV=true ./chat-history


79 changes: 55 additions & 24 deletions chat-history/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,52 +2,83 @@ package config

import (
"encoding/json"
"fmt"
"os"
)

type LLMConfig struct {
ModelName string `json:"model_name"`
}

type DbConfig struct {
Port string `json:"apiPort"`
DbPath string `json:"dbPath"`
DbLogPath string `json:"dbLogPath"`
LogPath string `json:"logPath"`
// DbHostname string `json:"hostname"`
// Username string `json:"username"`
// Password string `json:"password"`
type ChatDbConfig struct {
Port string `json:"apiPort"`
DbPath string `json:"dbPath"`
DbLogPath string `json:"dbLogPath"`
LogPath string `json:"logPath"`
ConversationAccessRoles []string `json:"conversationAccessRoles"`
}

type TgDbConfig struct {
Hostname string `json:"hostname"`
Username string `json:"username"`
Password string `json:"password"`
GsPort string `json:"gsPort"`
TgCloud bool `json:"tgCloud"`
// GetToken string `json:"getToken"`
// DefaultTimeout string `json:"default_timeout"`
// DefaultMemThreshold string `json:"default_mem_threshold"`
// DefaultThreadLimit string `json:"default_thread_limit"`
}

type Config struct {
DbConfig
ChatDbConfig
TgDbConfig
// LLMConfig
}

func LoadConfig(path string) (Config, error) {
var b []byte
if _, err := os.Stat(path); os.IsNotExist(err) {
// file doesn't exist read from env
cfg := os.Getenv("CONFIG")
if cfg == "" {
fmt.Println("CONFIG path is not found nor is the CONFIG json env variable defined")
os.Exit(1)
func LoadConfig(paths map[string]string) (Config, error) {
var config Config

// Load database config
if dbConfigPath, ok := paths["chatdb"]; ok {
dbConfig, err := loadChatDbConfig(dbConfigPath)
if err != nil {
return Config{}, err
}
b = []byte(cfg)
} else {
b, err = os.ReadFile(path)
config.ChatDbConfig = dbConfig
}

// Load TigerGraph config
if tgConfigPath, ok := paths["tgdb"]; ok {
tgConfig, err := loadTgDbConfig(tgConfigPath)
if err != nil {
return Config{}, err
}
config.TgDbConfig = tgConfig
}

var cfg Config
json.Unmarshal(b, &cfg)
return config, nil
}

return cfg, nil
func loadChatDbConfig(path string) (ChatDbConfig, error) {
var dbConfig ChatDbConfig
b, err := os.ReadFile(path)
if err != nil {
return ChatDbConfig{}, err
}
if err := json.Unmarshal(b, &dbConfig); err != nil {
return ChatDbConfig{}, err
}
return dbConfig, nil
}

func loadTgDbConfig(path string) (TgDbConfig, error) {
var tgConfig TgDbConfig
b, err := os.ReadFile(path)
if err != nil {
return TgDbConfig{}, err
}
if err := json.Unmarshal(b, &tgConfig); err != nil {
return TgDbConfig{}, err
}
return tgConfig, nil
}
68 changes: 43 additions & 25 deletions chat-history/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,41 +7,59 @@ import (
)

func TestLoadConfig(t *testing.T) {
pth := setup(t)
cfg, err := LoadConfig(pth)
chatConfigPath, tgConfigPath := setup(t)

cfg, err := LoadConfig(map[string]string{
"chatdb": chatConfigPath,
"tgdb": tgConfigPath,
})
if err != nil {
t.Fatal(err)
}

if cfg.Port != "8000" ||
cfg.DbPath != "chats.db" ||
cfg.DbLogPath != "db.log" ||
cfg.LogPath != "requestLogs.jsonl" {
t.Fatalf("config is wrong, %v", cfg)
if cfg.ChatDbConfig.Port != "8002" ||
cfg.ChatDbConfig.DbPath != "chats.db" ||
cfg.ChatDbConfig.DbLogPath != "db.log" ||
cfg.ChatDbConfig.LogPath != "requestLogs.jsonl" {
t.Fatalf("config is wrong, %v", cfg.ChatDbConfig)
}

if cfg.TgDbConfig.Hostname != "https://tg-0cdef603-3760-41c3-af6f-41e95afc40de.us-east-1.i.tgcloud.io" ||
cfg.TgDbConfig.GsPort != "14240" ||
cfg.TgDbConfig.TgCloud != true {
t.Fatalf("TigerGraph config is wrong, %v", cfg.TgDbConfig)
}
}

func setup(t *testing.T) string {
func setup(t *testing.T) (string, string) {
tmp := t.TempDir()
pth := fmt.Sprintf("%s/%s", tmp, "config.json")
dat := `

chatConfigPath := fmt.Sprintf("%s/%s", tmp, "chat_config.json")
chatConfigData := `
{
"apiPort":"8000",
"hostname": "http://localhost:14240",
"dbPath": "chats.db",
"dbLogPath": "db.log",
"logPath": "requestLogs.jsonl",
"username": "tigergraph",
"password": "tigergraph",
"getToken": false,
"default_timeout": 300,
"default_mem_threshold": 5000,
"default_thread_limit": 8
"apiPort":"8002",
"dbPath": "chats.db",
"dbLogPath": "db.log",
"logPath": "requestLogs.jsonl",
"conversationAccessRoles": ["superuser", "globaldesigner"]
}`
err := os.WriteFile(pth, []byte(dat), 0644)
if err != nil {
t.Fatal("error setting up config.json")

if err := os.WriteFile(chatConfigPath, []byte(chatConfigData), 0644); err != nil {
t.Fatal("error setting up chat_config.json")
}
return pth

tgConfigPath := fmt.Sprintf("%s/%s", tmp, "db_config.json")
tgConfigData := `
{
"hostname": "https://tg-0cdef603-3760-41c3-af6f-41e95afc40de.us-east-1.i.tgcloud.io",
"gsPort": "14240",
"username": "supportai",
"password": "supportai",
"tgCloud": true
}`
if err := os.WriteFile(tgConfigPath, []byte(tgConfigData), 0644); err != nil {
t.Fatal("error setting up tg_config.json")
}

return chatConfigPath, tgConfigPath
}
31 changes: 28 additions & 3 deletions chat-history/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,15 +131,28 @@ func UpdateConversationById(message structs.Message) (*structs.Conversation, err
return &convo, nil
}

// GetAllMessages retrieves all messages from the database
func GetAllMessages() ([]structs.Message, error) {
var messages []structs.Message

// Use GORM to query all messages
if err := db.Find(&messages).Error; err != nil {
return nil, err
}

return messages, nil
}

func populateDB() {
mu.Lock()
defer mu.Unlock()

// init convos
conv1 := uuid.MustParse("601529eb-4927-4e24-b285-bd6b9519a951")
conv2 := uuid.MustParse("601529eb-4927-4e24-b285-bd6b9519a952")
db.Create(&structs.Conversation{UserId: "sam_pull", ConversationId: conv1, Name: "conv1"})
db.Create(&structs.Conversation{UserId: "sam_pull", ConversationId: uuid.New(), Name: "conv2"})
db.Create(&structs.Conversation{UserId: "Miss_Take", ConversationId: uuid.New(), Name: "conv3"})
db.Create(&structs.Conversation{UserId: "Miss_Take", ConversationId: conv2, Name: "conv2"})
// db.Create(&structs.Conversation{UserId: "Miss_Take", ConversationId: uuid.New(), Name: "conv3"})

// add message to convos
message := structs.Message{
Expand All @@ -152,8 +165,8 @@ func populateDB() {
Feedback: structs.NoFeedback,
Comment: "",
}

db.Create(&message)

m2 := structs.Message{
ConversationId: conv1,
MessageId: uuid.New(),
Expand All @@ -165,4 +178,16 @@ func populateDB() {
Comment: "",
}
db.Create(&m2)

m3 := structs.Message{
ConversationId: conv2,
MessageId: uuid.New(),
ParentId: &message.MessageId,
ModelName: "GPT-4o",
Content: "How many transactions?",
Role: structs.SystemRole,
Feedback: structs.NoFeedback,
Comment: "",
}
db.Create(&m3)
}
23 changes: 23 additions & 0 deletions chat-history/db/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,29 @@ func TestParallelWrites(t *testing.T) {
}
}

func TestGetAllMessages(t *testing.T) {
setupTest(t, true)

messages, err := GetAllMessages()
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}

// Ensure that messages are returned
if len(messages) == 0 {
t.Fatalf("Expected some messages, got none")
}

// Validate the structure of the messages
for _, m := range messages {
if uuid.Validate(m.ConversationId.String()) != nil ||
uuid.Validate(m.MessageId.String()) != nil ||
(m.Role != "system" && m.Role != "user") {
t.Fatalf("Invaid message structure: %v", m)
}
}
}

/*
helper functions
*/
Expand Down
4 changes: 2 additions & 2 deletions chat-history/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@ module chat-history
go 1.22.3

require (
github.com/go-chi/chi/v5 v5.0.12
github.com/go-chi/httplog/v2 v2.0.11
github.com/google/uuid v1.6.0
gorm.io/driver/sqlite v1.5.5
gorm.io/gorm v1.25.10
)

require (
github.com/go-chi/httplog/v2 v2.0.11 // indirect
github.com/go-chi/chi/v5 v5.0.12 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/mattn/go-sqlite3 v1.14.22 // indirect
Expand Down
2 changes: 2 additions & 0 deletions chat-history/go.sum
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
github.com/GenericP3rson/TigerGo v0.0.4 h1:xI7d/cLJ6sRP4fzanInakARE0XGk1YAmvn5KrH1fwFU=
github.com/GenericP3rson/TigerGo v0.0.4/go.mod h1:PGpAFO9vNA7l34WSGYCtWb/eqVKHuIq1xqvizBlNhRM=
github.com/go-chi/chi/v5 v5.0.12 h1:9euLV5sTrTNTRUU9POmDUvfxyj6LAABLUcEWO+JJb4s=
github.com/go-chi/chi/v5 v5.0.12/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8=
github.com/go-chi/httplog/v2 v2.0.11 h1:eu6kYksMEJzBcOP+ba/iYudc0m5rv4VvBAzroJMkaY4=
Expand Down
17 changes: 12 additions & 5 deletions chat-history/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,18 @@ import (
)

func main() {
configPath:= os.Getenv("CONFIG")
config, err := config.LoadConfig(configPath)
configPath := os.Getenv("CONFIG_FILES")
// Split the paths into a slice
configPaths := strings.Split(configPath, ",")

cfg, err := config.LoadConfig(map[string]string{
"chatdb": configPaths[0],
"tgdb": configPaths[1],
})
if err != nil {
panic(err)
}
db.InitDB(config.DbPath, config.DbLogPath)
db.InitDB(cfg.ChatDbConfig.DbPath, cfg.ChatDbConfig.DbLogPath)

// make router
router := http.NewServeMux()
Expand All @@ -30,14 +36,15 @@ func main() {
router.HandleFunc("GET /user/{userId}", routes.GetUserConversations)
router.HandleFunc("GET /conversation/{conversationId}", routes.GetConversation)
router.HandleFunc("POST /conversation", routes.UpdateConversation)
router.HandleFunc("GET /get_feedback", routes.GetFeedback(cfg.TgDbConfig.Hostname, cfg.TgDbConfig.GsPort, cfg.ChatDbConfig.ConversationAccessRoles, cfg.TgDbConfig.TgCloud))

// create server with middleware
dev := strings.ToLower(os.Getenv("DEV")) == "true"
var port string
if dev {
port = fmt.Sprintf("localhost:%s", config.Port)
port = fmt.Sprintf("localhost:%s", cfg.ChatDbConfig.Port)
} else {
port = fmt.Sprintf(":%s", config.Port)
port = fmt.Sprintf(":%s", cfg.ChatDbConfig.Port)
}

handler := middleware.ChainMiddleware(router,
Expand Down
Loading
Loading