From 3d507d36ea2c7955d98a16b85cb7bc02c8923caa Mon Sep 17 00:00:00 2001 From: Steven Le Rouzic Date: Mon, 15 Apr 2024 23:17:09 +0200 Subject: User login & logout --- timer.go | 122 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 116 insertions(+), 6 deletions(-) (limited to 'timer.go') diff --git a/timer.go b/timer.go index e76f950..72d4094 100644 --- a/timer.go +++ b/timer.go @@ -12,6 +12,8 @@ import ( "golang.org/x/crypto/bcrypt" + "github.com/google/uuid" + _ "github.com/mattn/go-sqlite3" "stevenlr.com/timer/model" @@ -51,12 +53,12 @@ func initializeDatabase(db *sql.DB) error { return err } - err = insertTimer(tx, "My timer", 6) + err = insertTimer(tx, "My timer", 600) if err != nil { return err } - err = insertTimer(tx, "My timer2", 6) + err = insertTimer(tx, "My timer2", 600) if err != nil { return err } @@ -112,6 +114,30 @@ func queryAllTimers(db *sql.DB) []model.Timer { return timers } +func queryUserByName(db *sql.DB, name string) *model.User { + row := db.QueryRow("SELECT Id, Name, Salt, Password FROM User WHERE Name=$1", name) + if row == nil { + return nil + } + + var user model.User + row.Scan(&user.Id, &user.Name, &user.Salt, &user.Password) + + return &user +} + +func queryUserById(db *sql.DB, id model.UUID) *model.User { + row := db.QueryRow("SELECT Id, Name, Salt, Password FROM User WHERE Id=$1", id) + if row == nil { + return nil + } + + var user model.User + row.Scan(&user.Id, &user.Name, &user.Salt, &user.Password) + + return &user +} + func queryTimer(db *sql.DB, idStr string) *model.Timer { var id model.UUID if err := id.Scan(idStr); err != nil { @@ -154,7 +180,7 @@ func updateTimerEndTime(db *sql.DB, id model.UUID, endTime model.Time) bool { } type Session struct { - UserId []byte + UserId model.UUID } type MyServer struct { @@ -162,24 +188,57 @@ type MyServer struct { sessions map[string]Session } +const SessionCookieName = "timerSession" + +func removeCookie(cookieName string, w http.ResponseWriter) { + cookie := http.Cookie{ + Name: cookieName, + Value: "", + MaxAge: -1, + } + http.SetCookie(w, &cookie) +} + +func (server *MyServer) findCurrentUser(w http.ResponseWriter, r *http.Request) *model.User { + cookie, err := r.Cookie(SessionCookieName) + if err != nil { + return nil + } + + userId, ok := server.sessions[cookie.Value] + if !ok { + removeCookie(SessionCookieName, w) + return nil + } + + user := queryUserById(server.db, userId.UserId) + if user == nil { + removeCookie(SessionCookieName, w) + } + + return user +} + func (server *MyServer) handleNotFound(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusNotFound) - view.Main(view.Error404()).Render(context.Background(), w) + view.Error404().Render(context.Background(), w) } func (server *MyServer) handleMain(w http.ResponseWriter, r *http.Request) { + currentUser := server.findCurrentUser(w, r) if r.URL.Path == "/" { timers := queryAllTimers(server.db) - view.Main(view.TimersList(timers)).Render(context.Background(), w) + view.Main(view.TimersList(timers), currentUser).Render(context.Background(), w) } else { server.handleNotFound(w, r) } } func (server *MyServer) handleTimer(w http.ResponseWriter, r *http.Request) { + currentUser := server.findCurrentUser(w, r) timer := queryTimer(server.db, r.PathValue("timerId")) if timer != nil { - view.Main(view.TimerView(*timer)).Render(context.Background(), w) + view.Main(view.TimerView(*timer), currentUser).Render(context.Background(), w) } else { server.handleNotFound(w, r) } @@ -295,6 +354,55 @@ func (server *MyServer) handlePutTimer(w http.ResponseWriter, r *http.Request) { view.TimersList(timers).Render(context.Background(), w) } +func (server *MyServer) handlePostLogin(w http.ResponseWriter, r *http.Request) { + if server.findCurrentUser(w, r) != nil { + w.Header().Add("HX-Redirect", "/") + return + } + + userName := r.FormValue("user") + userPass := r.FormValue("password") + + user := queryUserByName(server.db, userName) + if user == nil { + w.WriteHeader(http.StatusBadRequest) + view.LoginFormError(nil, "Incorrect credentials").Render(context.Background(), w) + return + } + + err := bcrypt.CompareHashAndPassword(user.Password, []byte(user.Salt+userPass)) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + view.LoginFormError(nil, "Incorrect credentials").Render(context.Background(), w) + return + } + + sessionId, err := uuid.NewRandom() + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + view.LoginFormError(nil, "Internal server error").Render(context.Background(), w) + return + } + + cookie := http.Cookie{ + Name: SessionCookieName, + Value: sessionId.String(), + HttpOnly: true, + Secure: true, + } + server.sessions[sessionId.String()] = Session{UserId: user.Id} + http.SetCookie(w, &cookie) + w.Header().Add("HX-Redirect", "/") +} + +func (server *MyServer) handlePostLogout(w http.ResponseWriter, r *http.Request) { + if cookie, err := r.Cookie(SessionCookieName); err == nil { + delete(server.sessions, cookie.Value) + removeCookie(SessionCookieName, w) + } + w.Header().Add("HX-Redirect", "/") +} + func main() { log.Println("Starting...") @@ -313,6 +421,8 @@ func main() { fs := http.FileServer(http.Dir("static/")) http.Handle("GET /static/", http.StripPrefix("/static/", fs)) + http.HandleFunc("POST /login", myServer.handlePostLogin) + http.HandleFunc("POST /logout", myServer.handlePostLogout) http.HandleFunc("GET /timer/{timerId}", myServer.handleTimer) http.HandleFunc("POST /timer/{timerId}/addTime/{timeToAdd}", myServer.handleTimerAddTime) http.HandleFunc("DELETE /timer/{timerId}", myServer.handleDeleteTimer) -- cgit