diff options
-rw-r--r-- | database.go | 112 | ||||
-rw-r--r-- | htmx.go | 9 | ||||
-rw-r--r-- | model/user.go | 28 | ||||
-rw-r--r-- | timer.db | bin | 28672 -> 28672 bytes | |||
-rw-r--r-- | timer.go | 153 | ||||
-rw-r--r-- | utils.go | 15 |
6 files changed, 172 insertions, 145 deletions
diff --git a/database.go b/database.go new file mode 100644 index 0000000..f7cdf9b --- /dev/null +++ b/database.go @@ -0,0 +1,112 @@ +package main + +import ( + "database/sql" + "log" + + "golang.org/x/crypto/bcrypt" + + "stevenlr.com/timer/model" +) + +func initializeDatabaseV1(db *sql.DB) error { + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + _, err = tx.Exec(`PRAGMA user_version = 1`) + if err != nil { + return err + } + + _, err = tx.Exec(` + CREATE TABLE Timer ( + Id BLOB NOT NULL UNIQUE, + Name TEXT NOT NULL, + StartTime TEXT NOT NULL, + EndTime TEXT NOT NULL, + Owner BLOB NOT NULL, + Token TEXT NOT NULL UNIQUE, + PRIMARY KEY (Id) + )`) + if err != nil { + return err + } + + _, err = tx.Exec(` + CREATE TABLE User ( + Id BLOB NOT NULL UNIQUE, + Name TEXT NOT NULL, + Salt TEXT NOT NULL, + Password BLOB NOT NULL, + PRIMARY KEY (id) + )`) + if err != nil { + return err + } + + userName := "admin" + userPassword := "admin" + salt, err := GenerateRandomString(33) + if err != nil { + return err + } + + password, err := bcrypt.GenerateFromPassword([]byte(salt+userPassword), bcrypt.MinCost) + if err != nil { + return err + } + + _, err = tx.Exec(`INSERT INTO User VALUES ($1, $2, $3, $4)`, model.MakeUUID(), userName, salt, password) + if err != nil { + return err + } + + return tx.Commit() +} + +func migrateDatabaseV2(db *sql.DB) error { + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + _, err = tx.Exec(`PRAGMA user_version = 2`) + if err != nil { + return err + } + + _, err = tx.Exec("CREATE INDEX TimerTokenIndex ON Timer(Token)") + if err != nil { + return err + } + + return tx.Commit() +} + +func InitializeDatabase(db *sql.DB) error { + initialVersion := 0 + row := db.QueryRow("PRAGMA user_version") + row.Scan(&initialVersion) + + if initialVersion < 1 { + log.Println("Initializing DB V1") + err := initializeDatabaseV1(db) + if err != nil { + return err + } + } + + if initialVersion < 2 { + log.Println("Migrating DB to V2") + err := migrateDatabaseV2(db) + if err != nil { + return err + } + } + + return nil +} @@ -0,0 +1,9 @@ +package main + +import ( + "net/http" +) + +func HtmxRedirect(w http.ResponseWriter, url string) { + w.Header().Add("HX-Redirect", "/") +} diff --git a/model/user.go b/model/user.go index 4959371..09562bd 100644 --- a/model/user.go +++ b/model/user.go @@ -1,8 +1,36 @@ package model +import ( + "database/sql" +) + type User struct { Id UUID Name string Salt string Password []byte } + +func GetUserByName(db *sql.DB, name string) *User { + row := db.QueryRow("SELECT Id, Name, Salt, Password FROM User WHERE Name=$1", name) + if row == nil { + return nil + } + + var user User + row.Scan(&user.Id, &user.Name, &user.Salt, &user.Password) + + return &user +} + +func GetUserById(db *sql.DB, id UUID) *User { + row := db.QueryRow("SELECT Id, Name, Salt, Password FROM User WHERE Id=$1", id) + if row == nil { + return nil + } + + var user User + row.Scan(&user.Id, &user.Name, &user.Salt, &user.Password) + + return &user +} Binary files differ@@ -2,9 +2,7 @@ package main import ( "context" - "crypto/rand" "database/sql" - "encoding/base64" "errors" "fmt" "log" @@ -21,21 +19,12 @@ import ( "stevenlr.com/timer/view" ) -func generateRandomString(len int) (string, error) { - bin := make([]byte, len) - _, err := rand.Read(bin) - if err != nil { - return "", err - } - return base64.StdEncoding.EncodeToString(bin), nil -} - func generateSessionId() (string, error) { - return generateRandomString(66) + return GenerateRandomString(66) } func generateTimerToken() (string, error) { - return generateRandomString(66) + return GenerateRandomString(66) } func insertTimer(tx *sql.Tx, name string, seconds int, ownerId model.UUID) error { @@ -48,108 +37,6 @@ func insertTimer(tx *sql.Tx, name string, seconds int, ownerId model.UUID) error return err } -func initializeDatabaseV1(db *sql.DB) error { - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - - _, err = tx.Exec(`PRAGMA user_version = 1`) - if err != nil { - return err - } - - _, err = tx.Exec(` - CREATE TABLE Timer ( - Id BLOB NOT NULL UNIQUE, - Name TEXT NOT NULL, - StartTime TEXT NOT NULL, - EndTime TEXT NOT NULL, - Owner BLOB NOT NULL, - Token TEXT NOT NULL UNIQUE, - PRIMARY KEY (Id) - )`) - if err != nil { - return err - } - - _, err = tx.Exec(` - CREATE TABLE User ( - Id BLOB NOT NULL UNIQUE, - Name TEXT NOT NULL, - Salt TEXT NOT NULL, - Password BLOB NOT NULL, - PRIMARY KEY (id) - )`) - if err != nil { - return err - } - - userName := "admin" - userPassword := "admin" - salt, err := generateRandomString(33) - if err != nil { - return err - } - - password, err := bcrypt.GenerateFromPassword([]byte(salt+userPassword), bcrypt.MinCost) - if err != nil { - return err - } - - _, err = tx.Exec(`INSERT INTO User VALUES ($1, $2, $3, $4)`, model.MakeUUID(), userName, salt, password) - if err != nil { - return err - } - - return tx.Commit() -} - -func migrateDatabaseV2(db *sql.DB) error { - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - - _, err = tx.Exec(`PRAGMA user_version = 2`) - if err != nil { - return err - } - - _, err = tx.Exec("CREATE INDEX TimerTokenIndex ON Timer(Token)") - if err != nil { - return err - } - - return tx.Commit() -} - -func initializeDatabase(db *sql.DB) error { - initialVersion := 0 - row := db.QueryRow("PRAGMA user_version") - row.Scan(&initialVersion) - - if initialVersion < 1 { - log.Println("Initializing DB V1") - err := initializeDatabaseV1(db) - if err != nil { - return err - } - } - - if initialVersion < 2 { - log.Println("Migrating DB to V2") - err := migrateDatabaseV2(db) - if err != nil { - return err - } - } - - return nil -} - func queryAllTimers(db *sql.DB, owner model.UUID) []model.Timer { rows, err := db.Query("SELECT Id, Name FROM Timer WHERE Owner=$1", owner) if err != nil { @@ -167,30 +54,6 @@ func queryAllTimers(db *sql.DB, owner model.UUID) []model.Timer { return timers } -func queryUserByName(db *sql.DB, name string) *model.User { - row := db.QueryRow("SELECT Id, Name, Salt, Password FROM User WHERE Name=$1", name) - if row == nil { - return nil - } - - var user model.User - row.Scan(&user.Id, &user.Name, &user.Salt, &user.Password) - - return &user -} - -func queryUserById(db *sql.DB, id model.UUID) *model.User { - row := db.QueryRow("SELECT Id, Name, Salt, Password FROM User WHERE Id=$1", id) - if row == nil { - return nil - } - - var user model.User - row.Scan(&user.Id, &user.Name, &user.Salt, &user.Password) - - return &user -} - func queryTimerFromUser(db *sql.DB, idStr string, userId model.UUID) *model.Timer { var id model.UUID if err := id.Scan(idStr); err != nil { @@ -290,7 +153,7 @@ func (server *MyServer) findCurrentUser(w http.ResponseWriter, r *http.Request) return nil } - user := queryUserById(server.db, userId.UserId) + user := model.GetUserById(server.db, userId.UserId) if user == nil { removeCookie(SessionCookieName, w) } @@ -537,14 +400,14 @@ func (server *MyServer) handleCreateTimer(w http.ResponseWriter, r *http.Request func (server *MyServer) handlePostLogin(w http.ResponseWriter, r *http.Request) { if server.findCurrentUser(w, r) != nil { - w.Header().Add("HX-Redirect", "/") + HtmxRedirect(w, "/") return } userName := r.FormValue("user") userPass := r.FormValue("password") - user := queryUserByName(server.db, userName) + user := model.GetUserByName(server.db, userName) if user == nil { w.WriteHeader(http.StatusBadRequest) view.LoginFormError(nil, "Incorrect credentials").Render(context.Background(), w) @@ -573,7 +436,7 @@ func (server *MyServer) handlePostLogin(w http.ResponseWriter, r *http.Request) } server.sessions[sessionId] = Session{UserId: user.Id} http.SetCookie(w, &cookie) - w.Header().Add("HX-Redirect", "/") + HtmxRedirect(w, "/") } func (server *MyServer) handlePostLogout(w http.ResponseWriter, r *http.Request) { @@ -581,7 +444,7 @@ func (server *MyServer) handlePostLogout(w http.ResponseWriter, r *http.Request) delete(server.sessions, cookie.Value) removeCookie(SessionCookieName, w) } - w.Header().Add("HX-Redirect", "/") + HtmxRedirect(w, "/") } func main() { @@ -593,7 +456,7 @@ func main() { } defer db.Close() - if err := initializeDatabase(db); err != nil { + if err := InitializeDatabase(db); err != nil { log.Fatalln(err) } diff --git a/utils.go b/utils.go new file mode 100644 index 0000000..5aa2894 --- /dev/null +++ b/utils.go @@ -0,0 +1,15 @@ +package main + +import ( + "crypto/rand" + "encoding/base64" +) + +func GenerateRandomString(len int) (string, error) { + bin := make([]byte, len) + _, err := rand.Read(bin) + if err != nil { + return "", err + } + return base64.StdEncoding.EncodeToString(bin), nil +} |