diff --git a/postgres.go b/postgres.go index d3d2e29..edcb781 100644 --- a/postgres.go +++ b/postgres.go @@ -2,6 +2,7 @@ package main import ( "database/sql" + "fmt" _ "github.com/lib/pq" @@ -50,9 +51,9 @@ CREATE TABLE IF NOT EXISTS users ( statement.Exec() statement, _ = db.Prepare("CREATE UNIQUE INDEX IF NOT EXISTS idx_user_name on users(name)") statement.Exec() - statement, _ = db.Prepare("CREATE TABLE IF NOT EXISTS groups (id SERIAL PRIMARY KEY, name TEXT NOT NULL, gidnumber INTEGER NOT NULL)") + statement, _ = db.Prepare("CREATE TABLE IF NOT EXISTS ldapgroups (id SERIAL PRIMARY KEY, name TEXT NOT NULL, gidnumber INTEGER NOT NULL)") statement.Exec() - statement, _ = db.Prepare("CREATE UNIQUE INDEX IF NOT EXISTS idx_group_name on groups(name)") + statement, _ = db.Prepare("CREATE UNIQUE INDEX IF NOT EXISTS idx_group_name on ldapgroups(name)") statement.Exec() statement, _ = db.Prepare("CREATE TABLE IF NOT EXISTS includegroups (id SERIAL PRIMARY KEY, parentgroupid INTEGER NOT NULL, includegroupid INTEGER NOT NULL)") statement.Exec() @@ -66,4 +67,24 @@ func (b PostgresBackend) MigrateSchema(db *sql.DB, checker func(*sql.DB, string) statement, _ := db.Prepare("ALTER TABLE users ADD COLUMN sshkeys TEXT DEFAULT ''") statement.Exec() } + + if TableExists(db, "groups") { + // Drop the table created during schema creation + statement, _ := db.Prepare("DROP TABLE ldapgroups") + statement.Exec() + + statement, _ = db.Prepare("ALTER TABLE groups RENAME TO ldapgroups") + statement.Exec() + } +} + +// Indicates whether the table exists or not +func TableExists(db *sql.DB, tableName string) bool { + var found string + err := db.QueryRow(fmt.Sprintf("SELECT COUNT(id) FROM %s", tableName)).Scan( + &found) + if err != nil { + return false + } + return true }