diff --git a/internal/db_impl/sqlite3/client.go b/internal/db_impl/sqlite3/client.go index 2819d0c3..880e4a31 100644 --- a/internal/db_impl/sqlite3/client.go +++ b/internal/db_impl/sqlite3/client.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "io/fs" + "net/url" "os" "path/filepath" "sync" @@ -287,7 +288,9 @@ func pathExists(path string) (bool, error) { } func getDatabaseConn(dir, userID, path string) string { - return fmt.Sprintf("file:%v?cache=shared&_fk=1&_journal=WAL", path) + // We need to escape special characters in the db path, such as # + escapedPath := url.PathEscape(path) + return fmt.Sprintf("file:%v?cache=shared&_fk=1&_journal=WAL", escapedPath) } func TestUpdateDBVersion(ctx context.Context, dbPath, userID string, version int) error { diff --git a/internal/db_impl/sqlite3/client_test.go b/internal/db_impl/sqlite3/client_test.go new file mode 100644 index 00000000..3bf14afc --- /dev/null +++ b/internal/db_impl/sqlite3/client_test.go @@ -0,0 +1,49 @@ +package sqlite3 + +import ( + "database/sql" + "fmt" + "os" + "path/filepath" + "testing" +) + +func TestClient_DBConnectionSpecialCharacterPath(t *testing.T) { + dbDirs := []string{ + "#test", + "test_test", + "test#test#test", + "test$test$test", + } + + testingDir := t.TempDir() + + for _, dirName := range dbDirs { + path := filepath.Join(testingDir, dirName) + if err := os.MkdirAll(path, 0777); err != nil { + fmt.Println("Could not create testing directory, error: ", err) + t.FailNow() + } + + filePath := filepath.Join(path, "test.db") + + client, err := sql.Open("sqlite3", getDatabaseConn("test", "test", filePath)) + if err != nil { + fmt.Println("Could not connect to test database, error: ", err) + t.FailNow() + } + + if err := client.Ping(); err != nil { + fmt.Println("Could not ping test database, error: ", err) + if closeErr := client.Close(); closeErr != nil { + fmt.Println("Could not close test database, error: ", closeErr) + } + t.FailNow() + } + + if err := client.Close(); err != nil { + fmt.Println("Could not close test database, error: ", err) + t.FailNow() + } + } +}