Skip to content

Commit

Permalink
feat: Auth support
Browse files Browse the repository at this point in the history
- Added support for Basic and API token

Waiting on amikos-tech/chromadb-chart#39 to be implemented to add the integration tests.

Refs: #2
  • Loading branch information
tazarov committed Jan 8, 2024
1 parent a873a3b commit 5695e3a
Show file tree
Hide file tree
Showing 2 changed files with 226 additions and 3 deletions.
122 changes: 120 additions & 2 deletions chroma.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,129 @@ type Client struct {
ApiClient *openapiclient.APIClient //nolint
}

func NewClient(basePath string) *Client {
type AuthType string

const (
BASIC AuthType = "basic"
TokenAuthorization AuthType = "authorization"
TokenXChromaToken AuthType = "xchromatoken"
)

type AuthMethod interface {
GetCredentials() map[string]string
GetType() AuthType
}

type BasicAuth struct {
Username string
Password string
}

func (b BasicAuth) GetCredentials() map[string]string {
return map[string]string{
"username": b.Username,
"password": b.Password,
}
}

func (b BasicAuth) GetType() AuthType {
return BASIC
}

func NewBasicAuth(username string, password string) ClientAuthCredentials {
return ClientAuthCredentials{
AuthMethod: BasicAuth{
Username: username,
Password: password,
},
}
}

type AuthorizationTokenAuth struct {
Token string
}

func (t AuthorizationTokenAuth) GetType() AuthType {
return TokenAuthorization
}

func (t AuthorizationTokenAuth) GetCredentials() map[string]string {
return map[string]string{
"Authorization": "Bearer " + t.Token,
}
}

type XChromaTokenAuth struct {
Token string
}

func (t XChromaTokenAuth) GetType() AuthType {
return TokenXChromaToken
}

func (t XChromaTokenAuth) GetCredentials() map[string]string {
return map[string]string{
"X-Chroma-Token": t.Token,
}
}

type ClientAuthCredentials struct {
AuthMethod AuthMethod
}

func NewTokenAuth(token string, authType AuthType) ClientAuthCredentials {
switch {
case authType == TokenAuthorization:
return ClientAuthCredentials{
AuthMethod: AuthorizationTokenAuth{
Token: token,
},
}
case authType == TokenXChromaToken:
return ClientAuthCredentials{
AuthMethod: XChromaTokenAuth{
Token: token,
},
}
default:
panic("Invalid auth type")
}
}

type ClientConfig struct {
BasePath string
DefaultHeaders *map[string]string
ClientAuthCredentials *ClientAuthCredentials
}

func NewClientConfig(basePath string, defaultHeaders *map[string]string, clientAuthCredentials *ClientAuthCredentials) ClientConfig {
return ClientConfig{
BasePath: basePath,
DefaultHeaders: defaultHeaders,
ClientAuthCredentials: clientAuthCredentials,
}
}

func NewClient(config ClientConfig) *Client {
configuration := openapiclient.NewConfiguration()
if config.ClientAuthCredentials != nil {
// combine config.DefaultHeaders and config.AuthMethod.GetCredentials() maps
var headers = make(map[string]string)
if config.DefaultHeaders != nil {
for k, v := range *config.DefaultHeaders {
headers[k] = v
}
}
for k, v := range config.ClientAuthCredentials.AuthMethod.GetCredentials() {
headers[k] = v
}
configuration.DefaultHeader = headers
} else if config.DefaultHeaders != nil {
configuration.DefaultHeader = *config.DefaultHeaders
}
configuration.Servers = openapiclient.ServerConfigurations{
{
URL: basePath,
URL: config.BasePath,
Description: "No description provided",
},
}
Expand Down
107 changes: 106 additions & 1 deletion test/chroma_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ func Test_chroma_client(t *testing.T) {
if chromaURL == "" {
chromaURL = "http://localhost:8000"
}
client := chroma.NewClient(chromaURL)

clientConfig := chroma.NewClientConfig(chromaURL, nil, nil)
client := chroma.NewClient(clientConfig)

t.Run("Test Heartbeat", func(t *testing.T) {
resp, err := client.Heartbeat()
Expand Down Expand Up @@ -746,3 +748,106 @@ func Test_chroma_client(t *testing.T) {
require.Nil(t, addError)
})
}

func Test_chroma_client_with_basic(t *testing.T) {
chromaURL := os.Getenv("CHROMA_URL")
if chromaURL == "" {
chromaURL = "http://localhost:8003"
}
clientAuth := chroma.NewBasicAuth("test", "test")

clientConfig := chroma.NewClientConfig(chromaURL, nil, &clientAuth)
client := chroma.NewClient(clientConfig)

t.Run("Test Heartbeat", func(t *testing.T) {
resp, err := client.Heartbeat()

require.Nil(t, err)
require.NotNil(t, resp)
assert.Truef(t, resp["nanosecond heartbeat"] > 0, "Heartbeat should be greater than 0")
})
}

func Test_chroma_client_with_authorization_token(t *testing.T) {
chromaURL := os.Getenv("CHROMA_URL")
if chromaURL == "" {
chromaURL = "http://localhost:8001"
}
clientAuth := chroma.NewTokenAuth("test", chroma.TokenAuthorization)

clientConfig := chroma.NewClientConfig(chromaURL, nil, &clientAuth)
client := chroma.NewClient(clientConfig)

t.Run("Test List Collections", func(t *testing.T) {
collectionName1 := "test-collection1"
collectionName2 := "test-collection2"
metadata := map[string]string{}
apiKey := os.Getenv("OPENAI_API_KEY")
if apiKey == "" {
err := godotenv.Load("../.env")
if err != nil {
assert.Failf(t, "Error loading .env file", "%s", err)
}
apiKey = os.Getenv("OPENAI_API_KEY")
}
embeddingFunction := openai.NewOpenAIEmbeddingFunction(apiKey)
distanceFunction := chroma.L2
_, errRest := client.Reset()
if errRest != nil {
assert.Fail(t, fmt.Sprintf("Error resetting database: %s", errRest))
}
_, _ = client.CreateCollection(collectionName1, chroma.MapToAPI(metadata), true, embeddingFunction, distanceFunction)
_, _ = client.CreateCollection(collectionName2, chroma.MapToAPI(metadata), true, embeddingFunction, distanceFunction)
collections, gcerr := client.ListCollections()
require.Nil(t, gcerr)
assert.Equal(t, 2, len(collections))
names := make([]string, len(collections))
for i, person := range collections {
names[i] = person.Name
}
assert.Contains(t, names, collectionName1)
assert.Contains(t, names, collectionName2)
})
}

func Test_chroma_client_with_x_token(t *testing.T) {
chromaURL := os.Getenv("CHROMA_URL")
if chromaURL == "" {
chromaURL = "http://localhost:8002"
}
clientAuth := chroma.NewTokenAuth("test", chroma.TokenXChromaToken)

clientConfig := chroma.NewClientConfig(chromaURL, nil, &clientAuth)
client := chroma.NewClient(clientConfig)

t.Run("Test List Collections", func(t *testing.T) {
collectionName1 := "test-collection1"
collectionName2 := "test-collection2"
metadata := map[string]string{}
apiKey := os.Getenv("OPENAI_API_KEY")
if apiKey == "" {
err := godotenv.Load("../.env")
if err != nil {
assert.Failf(t, "Error loading .env file", "%s", err)
}
apiKey = os.Getenv("OPENAI_API_KEY")
}
embeddingFunction := openai.NewOpenAIEmbeddingFunction(apiKey)
distanceFunction := chroma.L2
_, errRest := client.Reset()
if errRest != nil {
assert.Fail(t, fmt.Sprintf("Error resetting database: %s", errRest))
}
_, _ = client.CreateCollection(collectionName1, chroma.MapToAPI(metadata), true, embeddingFunction, distanceFunction)
_, _ = client.CreateCollection(collectionName2, chroma.MapToAPI(metadata), true, embeddingFunction, distanceFunction)
collections, gcerr := client.ListCollections()
require.Nil(t, gcerr)
assert.Equal(t, 2, len(collections))
names := make([]string, len(collections))
for i, person := range collections {
names[i] = person.Name
}
assert.Contains(t, names, collectionName1)
assert.Contains(t, names, collectionName2)
})
}

0 comments on commit 5695e3a

Please sign in to comment.