diff options
Diffstat (limited to 'timer.go')
-rw-r--r-- | timer.go | 302 |
1 files changed, 72 insertions, 230 deletions
@@ -3,175 +3,40 @@ package main import ( "context" "database/sql" - "errors" "fmt" "log" "net/http" - "strconv" "strings" - "time" "golang.org/x/crypto/bcrypt" _ "github.com/mattn/go-sqlite3" "stevenlr.com/timer/model" + "stevenlr.com/timer/utils" "stevenlr.com/timer/view" ) -func generateSessionId() (string, error) { - return GenerateRandomString(66) -} - -func generateTimerToken() (string, error) { - return GenerateRandomString(66) -} - -func insertTimer(tx *sql.Tx, name string, seconds int, ownerId model.UUID) error { - now := model.MakeTimeNow() - end := model.Time(time.Time(now).Add(time.Duration(seconds) * time.Second)) - id := model.MakeUUID() - token, _ := generateTimerToken() - _, err := tx.Exec(` - INSERT INTO Timer VALUES ($1, $2, $3, $4, $5, $6)`, id, name, now, end, ownerId, token) - return err -} - -func queryAllTimers(db *sql.DB, owner model.UUID) []model.Timer { - rows, err := db.Query("SELECT Id, Name FROM Timer WHERE Owner=$1", owner) - if err != nil { - log.Fatalln(err) - } - - timers := []model.Timer{} - for rows.Next() { - var t model.Timer - if err := rows.Scan(&t.Id, &t.Name); err == nil { - timers = append(timers, t) - } - } - - return timers -} - -func queryTimerFromUser(db *sql.DB, idStr string, userId model.UUID) *model.Timer { - var id model.UUID - if err := id.Scan(idStr); err != nil { - return nil - } - - row := db.QueryRow("SELECT Id, Name, StartTime, EndTime, Owner, Token FROM Timer WHERE Id=$1 AND Owner=$2", id, userId) - - var t model.Timer - if err := row.Scan(&t.Id, &t.Name, &t.StartTime, &t.EndTime, &t.Owner, &t.Token); err == nil { - return &t - } - - return nil -} - -func queryTimerFromToken(db *sql.DB, idStr string, token string) *model.Timer { - var id model.UUID - if err := id.Scan(idStr); err != nil { - return nil - } - - row := db.QueryRow("SELECT Id, Name, StartTime, EndTime, Owner, Token FROM Timer WHERE Id=$1 AND Token=$2", id, token) - - var t model.Timer - if err := row.Scan(&t.Id, &t.Name, &t.StartTime, &t.EndTime, &t.Owner, &t.Token); err == nil { - return &t - } - - return nil -} - -func deleteTimer(db *sql.DB, idStr string, userId model.UUID) bool { - var id model.UUID - if err := id.Scan(idStr); err != nil { - return false - } - - res, err := db.Exec("DELETE FROM Timer WHERE Id=$1 AND Owner=$2", id, userId) - if err != nil { - return false - } - - affected, err := res.RowsAffected() - return err == nil && affected == 1 -} - -func updateTimerEndTime(db *sql.DB, id model.UUID, endTime model.Time, userId model.UUID) bool { - res, err := db.Exec("UPDATE Timer SET EndTime=$1 WHERE Id=$2 AND Owner=$3", endTime, id, userId) - if err != nil { - return false - } - - affected, err := res.RowsAffected() - return err == nil && affected == 1 -} - -func updateTimerToken(db *sql.DB, id model.UUID, token string, userId model.UUID) bool { - res, err := db.Exec("UPDATE Timer SET Token=$1 WHERE Id=$2 AND Owner=$3", token, id, userId) - if err != nil { - return false - } - - affected, err := res.RowsAffected() - return err == nil && affected == 1 -} - -type Session struct { - UserId model.UUID -} - -type MyServer struct { +type TimerServer struct { db *sql.DB - sessions map[string]Session -} - -const SessionCookieName = "timerSession" - -func removeCookie(cookieName string, w http.ResponseWriter) { - cookie := http.Cookie{ - Name: cookieName, - Value: "", - MaxAge: -1, - } - http.SetCookie(w, &cookie) + sessions Sessions } -func (server *MyServer) findCurrentUser(w http.ResponseWriter, r *http.Request) *model.User { - cookie, err := r.Cookie(SessionCookieName) - if err != nil { - return nil - } - - userId, ok := server.sessions[cookie.Value] - if !ok { - removeCookie(SessionCookieName, w) - return nil - } - - user := model.GetUserById(server.db, userId.UserId) - if user == nil { - removeCookie(SessionCookieName, w) - } - - return user +func (server *TimerServer) findCurrentUser(w http.ResponseWriter, r *http.Request) *model.User { + return server.sessions.FindCurrentUser(server.db, w, r) } -func (server *MyServer) handleNotFound(w http.ResponseWriter, _ *http.Request) { +func (server *TimerServer) handleNotFound(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusNotFound) view.Error404().Render(context.Background(), w) } -func (server *MyServer) handleMain(w http.ResponseWriter, r *http.Request) { +func (server *TimerServer) handleMain(w http.ResponseWriter, r *http.Request) { currentUser := server.findCurrentUser(w, r) if r.URL.Path == "/" { timers := make([]model.Timer, 0) if currentUser != nil { - timers = queryAllTimers(server.db, currentUser.Id) + timers = model.GetTimersForUser(server.db, currentUser.Id) } view.Main(view.TimersList(timers, currentUser != nil), currentUser).Render(context.Background(), w) } else { @@ -179,14 +44,20 @@ func (server *MyServer) handleMain(w http.ResponseWriter, r *http.Request) { } } -func (server *MyServer) handleTimer(w http.ResponseWriter, r *http.Request) { +func (server *TimerServer) handleTimer(w http.ResponseWriter, r *http.Request) { currentUser := server.findCurrentUser(w, r) if currentUser == nil { server.handleNotFound(w, r) return } - timer := queryTimerFromUser(server.db, r.PathValue("timerId"), currentUser.Id) + var id model.UUID + if err := id.Scan(r.PathValue("timerId")); err != nil { + server.handleNotFound(w, r) + return + } + + timer := model.GetTimerForUser(server.db, id, currentUser.Id) if timer != nil && timer.Owner == currentUser.Id { view.Main(view.TimerView(*timer), currentUser).Render(context.Background(), w) } else { @@ -194,44 +65,14 @@ func (server *MyServer) handleTimer(w http.ResponseWriter, r *http.Request) { } } -func parseDuration(value string) (time.Duration, error) { - const nullDuration = time.Duration(0) - if len(value) == 0 { - return nullDuration, errors.New("Empty duration string") - } - - var unit time.Duration - switch value[len(value)-1] { - case 's': - unit = time.Second - case 'm': - unit = time.Minute - case 'h': - unit = time.Hour - case 'd': - unit = time.Duration(24) * time.Hour - case 'w': - unit = time.Duration(24*7) * time.Hour - default: - return nullDuration, errors.New("Invalid duration format") - } - - amount, err := strconv.ParseInt(value[0:len(value)-1], 10, 64) - if err != nil || amount < 0 { - return nullDuration, errors.New("Invalid duration value") - } - - return time.Duration(amount) * unit, nil -} - -func (server *MyServer) handleTimerAddTimeCommon(w http.ResponseWriter, r *http.Request, timer *model.Timer) bool { +func (server *TimerServer) handleTimerAddTimeCommon(w http.ResponseWriter, r *http.Request, timer *model.Timer) bool { if timer.IsFinished() { w.WriteHeader(http.StatusBadRequest) w.Write([]byte("Timer already finished")) return false } - duration, err := parseDuration(r.FormValue("timeToAdd")) + duration, err := utils.ParseDuration(r.FormValue("timeToAdd")) if err != nil { w.WriteHeader(http.StatusBadRequest) w.Write([]byte(err.Error())) @@ -239,7 +80,7 @@ func (server *MyServer) handleTimerAddTimeCommon(w http.ResponseWriter, r *http. } timer.EndTime.Add(duration) - res := updateTimerEndTime(server.db, timer.Id, timer.EndTime, timer.Owner) + res := model.UpdateTimerEndTime(server.db, timer.Id, timer.EndTime, timer.Owner) if !res { w.WriteHeader(http.StatusInternalServerError) return false @@ -248,14 +89,20 @@ func (server *MyServer) handleTimerAddTimeCommon(w http.ResponseWriter, r *http. return true } -func (server *MyServer) handleTimerAddTime(w http.ResponseWriter, r *http.Request) { +func (server *TimerServer) handleTimerAddTime(w http.ResponseWriter, r *http.Request) { currentUser := server.findCurrentUser(w, r) if currentUser == nil { w.WriteHeader(http.StatusUnauthorized) return } - timer := queryTimerFromUser(server.db, r.PathValue("timerId"), currentUser.Id) + var id model.UUID + if err := id.Scan(r.PathValue("timerId")); err != nil { + w.WriteHeader(http.StatusNotFound) + return + } + + timer := model.GetTimerForUser(server.db, id, currentUser.Id) if timer == nil { w.WriteHeader(http.StatusNotFound) return @@ -268,8 +115,14 @@ func (server *MyServer) handleTimerAddTime(w http.ResponseWriter, r *http.Reques view.TimerInfo(*timer).Render(context.Background(), w) } -func (server *MyServer) handleApiTimerAddTime(w http.ResponseWriter, r *http.Request) { - timer := queryTimerFromToken(server.db, r.PathValue("timerId"), r.FormValue("token")) +func (server *TimerServer) handleApiTimerAddTime(w http.ResponseWriter, r *http.Request) { + var id model.UUID + if err := id.Scan(r.PathValue("timerId")); err != nil { + w.WriteHeader(http.StatusNotFound) + return + } + + timer := model.GetTimerWithToken(server.db, id, r.FormValue("token")) if timer == nil { w.WriteHeader(http.StatusNotFound) return @@ -280,14 +133,20 @@ func (server *MyServer) handleApiTimerAddTime(w http.ResponseWriter, r *http.Req } } -func (server *MyServer) handleGetTimerToken(w http.ResponseWriter, r *http.Request) { +func (server *TimerServer) handleGetTimerToken(w http.ResponseWriter, r *http.Request) { currentUser := server.findCurrentUser(w, r) if currentUser == nil { w.WriteHeader(http.StatusUnauthorized) return } - timer := queryTimerFromUser(server.db, r.PathValue("timerId"), currentUser.Id) + var id model.UUID + if err := id.Scan(r.PathValue("timerId")); err != nil { + w.WriteHeader(http.StatusNotFound) + return + } + + timer := model.GetTimerForUser(server.db, id, currentUser.Id) if timer == nil { server.handleNotFound(w, r) return @@ -296,27 +155,26 @@ func (server *MyServer) handleGetTimerToken(w http.ResponseWriter, r *http.Reque w.Write([]byte(fmt.Sprint("<code>", timer.Token, "</code>"))) } -func (server *MyServer) handleResetTimerToken(w http.ResponseWriter, r *http.Request) { +func (server *TimerServer) handleResetTimerToken(w http.ResponseWriter, r *http.Request) { currentUser := server.findCurrentUser(w, r) if currentUser == nil { w.WriteHeader(http.StatusUnauthorized) return } - timer := queryTimerFromUser(server.db, r.PathValue("timerId"), currentUser.Id) - if timer == nil { - server.handleNotFound(w, r) + var id model.UUID + if err := id.Scan(r.PathValue("timerId")); err != nil { + w.WriteHeader(http.StatusNotFound) return } - newToken, err := generateTimerToken() - if err != nil { - w.WriteHeader(http.StatusInternalServerError) + timer := model.GetTimerForUser(server.db, id, currentUser.Id) + if timer == nil { + w.WriteHeader(http.StatusNotFound) return } - timer.Token = newToken - res := updateTimerToken(server.db, timer.Id, newToken, currentUser.Id) + res := model.RegenerateTimerToken(server.db, timer.Id, currentUser.Id) if !res { w.WriteHeader(http.StatusInternalServerError) return @@ -325,29 +183,26 @@ func (server *MyServer) handleResetTimerToken(w http.ResponseWriter, r *http.Req view.TimerTokenForm(*timer).Render(context.Background(), w) } -func (server *MyServer) handleDeleteTimer(w http.ResponseWriter, r *http.Request) { +func (server *TimerServer) handleDeleteTimer(w http.ResponseWriter, r *http.Request) { user := server.findCurrentUser(w, r) if user == nil { w.WriteHeader(http.StatusUnauthorized) return } - success := deleteTimer(server.db, r.PathValue("timerId"), user.Id) - if !success { + var id model.UUID + if err := id.Scan(r.PathValue("timerId")); err != nil { w.WriteHeader(http.StatusNotFound) + return } -} -func parseNumber(s string) (int64, error) { - s = strings.TrimSpace(s) - if len(s) == 0 { - s = "0" + success := model.DeleteTimer(server.db, id, user.Id) + if !success { + w.WriteHeader(http.StatusNotFound) } - - return strconv.ParseInt(s, 10, 64) } -func (server *MyServer) handleCreateTimer(w http.ResponseWriter, r *http.Request) { +func (server *TimerServer) handleCreateTimer(w http.ResponseWriter, r *http.Request) { timerName := strings.TrimSpace(r.FormValue("timerName")) user := server.findCurrentUser(w, r) @@ -357,14 +212,14 @@ func (server *MyServer) handleCreateTimer(w http.ResponseWriter, r *http.Request return } - days, err := parseNumber(r.FormValue("days")) + days, err := utils.ParseNumber(r.FormValue("days")) if err != nil { w.WriteHeader(http.StatusBadRequest) view.TimerCreateForm(timerName, "Error parsing days").Render(context.Background(), w) return } - hours, err := parseNumber(r.FormValue("hours")) + hours, err := utils.ParseNumber(r.FormValue("hours")) if err != nil { w.WriteHeader(http.StatusBadRequest) view.TimerCreateForm(timerName, "Error parsing hours").Render(context.Background(), w) @@ -385,7 +240,7 @@ func (server *MyServer) handleCreateTimer(w http.ResponseWriter, r *http.Request return } - err = insertTimer(tx, timerName, int(((max(days, 0)*24)+max(hours, 0))*3600), user.Id) + err = model.InsertTimer(tx, timerName, int(((max(days, 0)*24)+max(hours, 0))*3600), user.Id) if err != nil { w.WriteHeader(http.StatusInternalServerError) view.TimerCreateForm(timerName, "Internal server error").Render(context.Background(), w) @@ -394,13 +249,13 @@ func (server *MyServer) handleCreateTimer(w http.ResponseWriter, r *http.Request tx.Commit() - timers := queryAllTimers(server.db, user.Id) + timers := model.GetTimersForUser(server.db, user.Id) view.TimersList(timers, user != nil).Render(context.Background(), w) } -func (server *MyServer) handlePostLogin(w http.ResponseWriter, r *http.Request) { +func (server *TimerServer) handlePostLogin(w http.ResponseWriter, r *http.Request) { if server.findCurrentUser(w, r) != nil { - HtmxRedirect(w, "/") + utils.HtmxRedirect(w, "/") return } @@ -421,30 +276,17 @@ func (server *MyServer) handlePostLogin(w http.ResponseWriter, r *http.Request) return } - sessionId, err := generateSessionId() - if err != nil { + if err := server.sessions.StartSession(user.Id, w); err == nil { + utils.HtmxRedirect(w, "/") + } else { w.WriteHeader(http.StatusInternalServerError) view.LoginFormError(nil, "Internal server error").Render(context.Background(), w) - return } - - cookie := http.Cookie{ - Name: SessionCookieName, - Value: sessionId, - HttpOnly: true, - Secure: true, - } - server.sessions[sessionId] = Session{UserId: user.Id} - http.SetCookie(w, &cookie) - HtmxRedirect(w, "/") } -func (server *MyServer) handlePostLogout(w http.ResponseWriter, r *http.Request) { - if cookie, err := r.Cookie(SessionCookieName); err == nil { - delete(server.sessions, cookie.Value) - removeCookie(SessionCookieName, w) - } - HtmxRedirect(w, "/") +func (server *TimerServer) handlePostLogout(w http.ResponseWriter, r *http.Request) { + server.sessions.EndSession(w, r) + utils.HtmxRedirect(w, "/") } func main() { @@ -460,7 +302,7 @@ func main() { log.Fatalln(err) } - myServer := MyServer{db: db, sessions: make(map[string]Session)} + myServer := TimerServer{db: db, sessions: MakeSessions()} fs := http.FileServer(http.Dir("static/")) http.Handle("GET /static/", http.StripPrefix("/static/", fs)) |