summaryrefslogtreecommitdiff
path: root/database.go
diff options
context:
space:
mode:
Diffstat (limited to 'database.go')
-rw-r--r--database.go112
1 files changed, 112 insertions, 0 deletions
diff --git a/database.go b/database.go
new file mode 100644
index 0000000..f7cdf9b
--- /dev/null
+++ b/database.go
@@ -0,0 +1,112 @@
+package main
+
+import (
+ "database/sql"
+ "log"
+
+ "golang.org/x/crypto/bcrypt"
+
+ "stevenlr.com/timer/model"
+)
+
+func initializeDatabaseV1(db *sql.DB) error {
+ tx, err := db.Begin()
+ if err != nil {
+ return err
+ }
+ defer tx.Rollback()
+
+ _, err = tx.Exec(`PRAGMA user_version = 1`)
+ if err != nil {
+ return err
+ }
+
+ _, err = tx.Exec(`
+ CREATE TABLE Timer (
+ Id BLOB NOT NULL UNIQUE,
+ Name TEXT NOT NULL,
+ StartTime TEXT NOT NULL,
+ EndTime TEXT NOT NULL,
+ Owner BLOB NOT NULL,
+ Token TEXT NOT NULL UNIQUE,
+ PRIMARY KEY (Id)
+ )`)
+ if err != nil {
+ return err
+ }
+
+ _, err = tx.Exec(`
+ CREATE TABLE User (
+ Id BLOB NOT NULL UNIQUE,
+ Name TEXT NOT NULL,
+ Salt TEXT NOT NULL,
+ Password BLOB NOT NULL,
+ PRIMARY KEY (id)
+ )`)
+ if err != nil {
+ return err
+ }
+
+ userName := "admin"
+ userPassword := "admin"
+ salt, err := GenerateRandomString(33)
+ if err != nil {
+ return err
+ }
+
+ password, err := bcrypt.GenerateFromPassword([]byte(salt+userPassword), bcrypt.MinCost)
+ if err != nil {
+ return err
+ }
+
+ _, err = tx.Exec(`INSERT INTO User VALUES ($1, $2, $3, $4)`, model.MakeUUID(), userName, salt, password)
+ if err != nil {
+ return err
+ }
+
+ return tx.Commit()
+}
+
+func migrateDatabaseV2(db *sql.DB) error {
+ tx, err := db.Begin()
+ if err != nil {
+ return err
+ }
+ defer tx.Rollback()
+
+ _, err = tx.Exec(`PRAGMA user_version = 2`)
+ if err != nil {
+ return err
+ }
+
+ _, err = tx.Exec("CREATE INDEX TimerTokenIndex ON Timer(Token)")
+ if err != nil {
+ return err
+ }
+
+ return tx.Commit()
+}
+
+func InitializeDatabase(db *sql.DB) error {
+ initialVersion := 0
+ row := db.QueryRow("PRAGMA user_version")
+ row.Scan(&initialVersion)
+
+ if initialVersion < 1 {
+ log.Println("Initializing DB V1")
+ err := initializeDatabaseV1(db)
+ if err != nil {
+ return err
+ }
+ }
+
+ if initialVersion < 2 {
+ log.Println("Migrating DB to V2")
+ err := migrateDatabaseV2(db)
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+}