diff options
-rw-r--r-- | database.go | 3 | ||||
-rw-r--r-- | model/timer.go | 96 | ||||
-rw-r--r-- | session.go | 84 | ||||
-rw-r--r-- | timer.db | bin | 28672 -> 28672 bytes | |||
-rw-r--r-- | timer.go | 302 | ||||
-rw-r--r-- | utils.go | 15 | ||||
-rw-r--r-- | utils/htmx.go (renamed from htmx.go) | 2 | ||||
-rw-r--r-- | utils/utils.go | 58 |
8 files changed, 313 insertions, 247 deletions
diff --git a/database.go b/database.go index f7cdf9b..583974f 100644 --- a/database.go +++ b/database.go @@ -7,6 +7,7 @@ import ( "golang.org/x/crypto/bcrypt" "stevenlr.com/timer/model" + "stevenlr.com/timer/utils" ) func initializeDatabaseV1(db *sql.DB) error { @@ -49,7 +50,7 @@ func initializeDatabaseV1(db *sql.DB) error { userName := "admin" userPassword := "admin" - salt, err := GenerateRandomString(33) + salt, err := utils.GenerateRandomString(33) if err != nil { return err } diff --git a/model/timer.go b/model/timer.go index 27e46da..3f13d0d 100644 --- a/model/timer.go +++ b/model/timer.go @@ -1,5 +1,17 @@ package model +import ( + "database/sql" + "log" + "time" + + "stevenlr.com/timer/utils" +) + +func GenerateTimerToken() (string, error) { + return utils.GenerateRandomString(66) +} + type Timer struct { Id UUID Name string @@ -12,3 +24,87 @@ type Timer struct { func (self Timer) IsFinished() bool { return MakeTimeNow().Compare(self.EndTime) >= 0 } + +func InsertTimer(tx *sql.Tx, name string, seconds int, ownerId UUID) error { + now := MakeTimeNow() + end := Time(time.Time(now).Add(time.Duration(seconds) * time.Second)) + id := MakeUUID() + token, _ := GenerateTimerToken() + _, err := tx.Exec(` + INSERT INTO Timer VALUES ($1, $2, $3, $4, $5, $6)`, id, name, now, end, ownerId, token) + return err +} + +func GetTimersForUser(db *sql.DB, owner UUID) []Timer { + rows, err := db.Query("SELECT Id, Name FROM Timer WHERE Owner=$1", owner) + if err != nil { + log.Fatalln(err) + } + + timers := []Timer{} + for rows.Next() { + var t Timer + if err := rows.Scan(&t.Id, &t.Name); err == nil { + timers = append(timers, t) + } + } + + return timers +} + +func GetTimerForUser(db *sql.DB, id UUID, userId UUID) *Timer { + row := db.QueryRow("SELECT Id, Name, StartTime, EndTime, Owner, Token FROM Timer WHERE Id=$1 AND Owner=$2", id, userId) + + var t Timer + if err := row.Scan(&t.Id, &t.Name, &t.StartTime, &t.EndTime, &t.Owner, &t.Token); err == nil { + return &t + } + + return nil +} + +func GetTimerWithToken(db *sql.DB, id UUID, token string) *Timer { + row := db.QueryRow("SELECT Id, Name, StartTime, EndTime, Owner, Token FROM Timer WHERE Id=$1 AND Token=$2", id, token) + + var t Timer + if err := row.Scan(&t.Id, &t.Name, &t.StartTime, &t.EndTime, &t.Owner, &t.Token); err == nil { + return &t + } + + return nil +} + +func DeleteTimer(db *sql.DB, id UUID, userId UUID) bool { + res, err := db.Exec("DELETE FROM Timer WHERE Id=$1 AND Owner=$2", id, userId) + if err != nil { + return false + } + + affected, err := res.RowsAffected() + return err == nil && affected == 1 +} + +func UpdateTimerEndTime(db *sql.DB, id UUID, endTime Time, userId UUID) bool { + res, err := db.Exec("UPDATE Timer SET EndTime=$1 WHERE Id=$2 AND Owner=$3", endTime, id, userId) + if err != nil { + return false + } + + affected, err := res.RowsAffected() + return err == nil && affected == 1 +} + +func RegenerateTimerToken(db *sql.DB, id UUID, userId UUID) bool { + newToken, err := GenerateTimerToken() + if err != nil { + return false + } + + res, err := db.Exec("UPDATE Timer SET Token=$1 WHERE Id=$2 AND Owner=$3", newToken, id, userId) + if err != nil { + return false + } + + affected, err := res.RowsAffected() + return err == nil && affected == 1 +} diff --git a/session.go b/session.go new file mode 100644 index 0000000..e32041f --- /dev/null +++ b/session.go @@ -0,0 +1,84 @@ +package main + +import ( + "database/sql" + "errors" + "net/http" + + "stevenlr.com/timer/model" + "stevenlr.com/timer/utils" +) + +func generateSessionId() (string, error) { + return utils.GenerateRandomString(66) +} + +type Sessions struct { + sessions map[string]Session +} + +type Session struct { + UserId model.UUID +} + +const sessionCookieName = "timerSession" + +func removeCookie(cookieName string, w http.ResponseWriter) { + cookie := http.Cookie{ + Name: cookieName, + Value: "", + MaxAge: -1, + } + http.SetCookie(w, &cookie) +} + +func MakeSessions() Sessions { + return Sessions{ + sessions: make(map[string]Session), + } +} + +func (sessions *Sessions) FindCurrentUser(db *sql.DB, w http.ResponseWriter, r *http.Request) *model.User { + cookie, err := r.Cookie(sessionCookieName) + if err != nil { + return nil + } + + userId, ok := sessions.sessions[cookie.Value] + if !ok { + removeCookie(sessionCookieName, w) + return nil + } + + user := model.GetUserById(db, userId.UserId) + if user == nil { + removeCookie(sessionCookieName, w) + } + + return user +} + +func (sessions *Sessions) StartSession(user model.UUID, w http.ResponseWriter) error { + sessionId, err := generateSessionId() + if err != nil { + return errors.New("Couldn't generate session ID") + } + + cookie := http.Cookie{ + Name: sessionCookieName, + Value: sessionId, + HttpOnly: true, + Secure: true, + } + + sessions.sessions[sessionId] = Session{UserId: user} + http.SetCookie(w, &cookie) + return nil +} + +func (sessions *Sessions) EndSession(w http.ResponseWriter, r *http.Request) { + if cookie, err := r.Cookie(sessionCookieName); err == nil { + delete(sessions.sessions, cookie.Value) + removeCookie(sessionCookieName, w) + } +} Binary files differ@@ -3,175 +3,40 @@ package main import ( "context" "database/sql" - "errors" "fmt" "log" "net/http" - "strconv" "strings" - "time" "golang.org/x/crypto/bcrypt" _ "github.com/mattn/go-sqlite3" "stevenlr.com/timer/model" + "stevenlr.com/timer/utils" "stevenlr.com/timer/view" ) -func generateSessionId() (string, error) { - return GenerateRandomString(66) -} - -func generateTimerToken() (string, error) { - return GenerateRandomString(66) -} - -func insertTimer(tx *sql.Tx, name string, seconds int, ownerId model.UUID) error { - now := model.MakeTimeNow() - end := model.Time(time.Time(now).Add(time.Duration(seconds) * time.Second)) - id := model.MakeUUID() - token, _ := generateTimerToken() - _, err := tx.Exec(` - INSERT INTO Timer VALUES ($1, $2, $3, $4, $5, $6)`, id, name, now, end, ownerId, token) - return err -} - -func queryAllTimers(db *sql.DB, owner model.UUID) []model.Timer { - rows, err := db.Query("SELECT Id, Name FROM Timer WHERE Owner=$1", owner) - if err != nil { - log.Fatalln(err) - } - - timers := []model.Timer{} - for rows.Next() { - var t model.Timer - if err := rows.Scan(&t.Id, &t.Name); err == nil { - timers = append(timers, t) - } - } - - return timers -} - -func queryTimerFromUser(db *sql.DB, idStr string, userId model.UUID) *model.Timer { - var id model.UUID - if err := id.Scan(idStr); err != nil { - return nil - } - - row := db.QueryRow("SELECT Id, Name, StartTime, EndTime, Owner, Token FROM Timer WHERE Id=$1 AND Owner=$2", id, userId) - - var t model.Timer - if err := row.Scan(&t.Id, &t.Name, &t.StartTime, &t.EndTime, &t.Owner, &t.Token); err == nil { - return &t - } - - return nil -} - -func queryTimerFromToken(db *sql.DB, idStr string, token string) *model.Timer { - var id model.UUID - if err := id.Scan(idStr); err != nil { - return nil - } - - row := db.QueryRow("SELECT Id, Name, StartTime, EndTime, Owner, Token FROM Timer WHERE Id=$1 AND Token=$2", id, token) - - var t model.Timer - if err := row.Scan(&t.Id, &t.Name, &t.StartTime, &t.EndTime, &t.Owner, &t.Token); err == nil { - return &t - } - - return nil -} - -func deleteTimer(db *sql.DB, idStr string, userId model.UUID) bool { - var id model.UUID - if err := id.Scan(idStr); err != nil { - return false - } - - res, err := db.Exec("DELETE FROM Timer WHERE Id=$1 AND Owner=$2", id, userId) - if err != nil { - return false - } - - affected, err := res.RowsAffected() - return err == nil && affected == 1 -} - -func updateTimerEndTime(db *sql.DB, id model.UUID, endTime model.Time, userId model.UUID) bool { - res, err := db.Exec("UPDATE Timer SET EndTime=$1 WHERE Id=$2 AND Owner=$3", endTime, id, userId) - if err != nil { - return false - } - - affected, err := res.RowsAffected() - return err == nil && affected == 1 -} - -func updateTimerToken(db *sql.DB, id model.UUID, token string, userId model.UUID) bool { - res, err := db.Exec("UPDATE Timer SET Token=$1 WHERE Id=$2 AND Owner=$3", token, id, userId) - if err != nil { - return false - } - - affected, err := res.RowsAffected() - return err == nil && affected == 1 -} - -type Session struct { - UserId model.UUID -} - -type MyServer struct { +type TimerServer struct { db *sql.DB - sessions map[string]Session -} - -const SessionCookieName = "timerSession" - -func removeCookie(cookieName string, w http.ResponseWriter) { - cookie := http.Cookie{ - Name: cookieName, - Value: "", - MaxAge: -1, - } - http.SetCookie(w, &cookie) + sessions Sessions } -func (server *MyServer) findCurrentUser(w http.ResponseWriter, r *http.Request) *model.User { - cookie, err := r.Cookie(SessionCookieName) - if err != nil { - return nil - } - - userId, ok := server.sessions[cookie.Value] - if !ok { - removeCookie(SessionCookieName, w) - return nil - } - - user := model.GetUserById(server.db, userId.UserId) - if user == nil { - removeCookie(SessionCookieName, w) - } - - return user +func (server *TimerServer) findCurrentUser(w http.ResponseWriter, r *http.Request) *model.User { + return server.sessions.FindCurrentUser(server.db, w, r) } -func (server *MyServer) handleNotFound(w http.ResponseWriter, _ *http.Request) { +func (server *TimerServer) handleNotFound(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusNotFound) view.Error404().Render(context.Background(), w) } -func (server *MyServer) handleMain(w http.ResponseWriter, r *http.Request) { +func (server *TimerServer) handleMain(w http.ResponseWriter, r *http.Request) { currentUser := server.findCurrentUser(w, r) if r.URL.Path == "/" { timers := make([]model.Timer, 0) if currentUser != nil { - timers = queryAllTimers(server.db, currentUser.Id) + timers = model.GetTimersForUser(server.db, currentUser.Id) } view.Main(view.TimersList(timers, currentUser != nil), currentUser).Render(context.Background(), w) } else { @@ -179,14 +44,20 @@ func (server *MyServer) handleMain(w http.ResponseWriter, r *http.Request) { } } -func (server *MyServer) handleTimer(w http.ResponseWriter, r *http.Request) { +func (server *TimerServer) handleTimer(w http.ResponseWriter, r *http.Request) { currentUser := server.findCurrentUser(w, r) if currentUser == nil { server.handleNotFound(w, r) return } - timer := queryTimerFromUser(server.db, r.PathValue("timerId"), currentUser.Id) + var id model.UUID + if err := id.Scan(r.PathValue("timerId")); err != nil { + server.handleNotFound(w, r) + return + } + + timer := model.GetTimerForUser(server.db, id, currentUser.Id) if timer != nil && timer.Owner == currentUser.Id { view.Main(view.TimerView(*timer), currentUser).Render(context.Background(), w) } else { @@ -194,44 +65,14 @@ func (server *MyServer) handleTimer(w http.ResponseWriter, r *http.Request) { } } -func parseDuration(value string) (time.Duration, error) { - const nullDuration = time.Duration(0) - if len(value) == 0 { - return nullDuration, errors.New("Empty duration string") - } - - var unit time.Duration - switch value[len(value)-1] { - case 's': - unit = time.Second - case 'm': - unit = time.Minute - case 'h': - unit = time.Hour - case 'd': - unit = time.Duration(24) * time.Hour - case 'w': - unit = time.Duration(24*7) * time.Hour - default: - return nullDuration, errors.New("Invalid duration format") - } - - amount, err := strconv.ParseInt(value[0:len(value)-1], 10, 64) - if err != nil || amount < 0 { - return nullDuration, errors.New("Invalid duration value") - } - - return time.Duration(amount) * unit, nil -} - -func (server *MyServer) handleTimerAddTimeCommon(w http.ResponseWriter, r *http.Request, timer *model.Timer) bool { +func (server *TimerServer) handleTimerAddTimeCommon(w http.ResponseWriter, r *http.Request, timer *model.Timer) bool { if timer.IsFinished() { w.WriteHeader(http.StatusBadRequest) w.Write([]byte("Timer already finished")) return false } - duration, err := parseDuration(r.FormValue("timeToAdd")) + duration, err := utils.ParseDuration(r.FormValue("timeToAdd")) if err != nil { w.WriteHeader(http.StatusBadRequest) w.Write([]byte(err.Error())) @@ -239,7 +80,7 @@ func (server *MyServer) handleTimerAddTimeCommon(w http.ResponseWriter, r *http. } timer.EndTime.Add(duration) - res := updateTimerEndTime(server.db, timer.Id, timer.EndTime, timer.Owner) + res := model.UpdateTimerEndTime(server.db, timer.Id, timer.EndTime, timer.Owner) if !res { w.WriteHeader(http.StatusInternalServerError) return false @@ -248,14 +89,20 @@ func (server *MyServer) handleTimerAddTimeCommon(w http.ResponseWriter, r *http. return true } -func (server *MyServer) handleTimerAddTime(w http.ResponseWriter, r *http.Request) { +func (server *TimerServer) handleTimerAddTime(w http.ResponseWriter, r *http.Request) { currentUser := server.findCurrentUser(w, r) if currentUser == nil { w.WriteHeader(http.StatusUnauthorized) return } - timer := queryTimerFromUser(server.db, r.PathValue("timerId"), currentUser.Id) + var id model.UUID + if err := id.Scan(r.PathValue("timerId")); err != nil { + w.WriteHeader(http.StatusNotFound) + return + } + + timer := model.GetTimerForUser(server.db, id, currentUser.Id) if timer == nil { w.WriteHeader(http.StatusNotFound) return @@ -268,8 +115,14 @@ func (server *MyServer) handleTimerAddTime(w http.ResponseWriter, r *http.Reques view.TimerInfo(*timer).Render(context.Background(), w) } -func (server *MyServer) handleApiTimerAddTime(w http.ResponseWriter, r *http.Request) { - timer := queryTimerFromToken(server.db, r.PathValue("timerId"), r.FormValue("token")) +func (server *TimerServer) handleApiTimerAddTime(w http.ResponseWriter, r *http.Request) { + var id model.UUID + if err := id.Scan(r.PathValue("timerId")); err != nil { + w.WriteHeader(http.StatusNotFound) + return + } + + timer := model.GetTimerWithToken(server.db, id, r.FormValue("token")) if timer == nil { w.WriteHeader(http.StatusNotFound) return @@ -280,14 +133,20 @@ func (server *MyServer) handleApiTimerAddTime(w http.ResponseWriter, r *http.Req } } -func (server *MyServer) handleGetTimerToken(w http.ResponseWriter, r *http.Request) { +func (server *TimerServer) handleGetTimerToken(w http.ResponseWriter, r *http.Request) { currentUser := server.findCurrentUser(w, r) if currentUser == nil { w.WriteHeader(http.StatusUnauthorized) return } - timer := queryTimerFromUser(server.db, r.PathValue("timerId"), currentUser.Id) + var id model.UUID + if err := id.Scan(r.PathValue("timerId")); err != nil { + w.WriteHeader(http.StatusNotFound) + return + } + + timer := model.GetTimerForUser(server.db, id, currentUser.Id) if timer == nil { server.handleNotFound(w, r) return @@ -296,27 +155,26 @@ func (server *MyServer) handleGetTimerToken(w http.ResponseWriter, r *http.Reque w.Write([]byte(fmt.Sprint("<code>", timer.Token, "</code>"))) } -func (server *MyServer) handleResetTimerToken(w http.ResponseWriter, r *http.Request) { +func (server *TimerServer) handleResetTimerToken(w http.ResponseWriter, r *http.Request) { currentUser := server.findCurrentUser(w, r) if currentUser == nil { w.WriteHeader(http.StatusUnauthorized) return } - timer := queryTimerFromUser(server.db, r.PathValue("timerId"), currentUser.Id) - if timer == nil { - server.handleNotFound(w, r) + var id model.UUID + if err := id.Scan(r.PathValue("timerId")); err != nil { + w.WriteHeader(http.StatusNotFound) return } - newToken, err := generateTimerToken() - if err != nil { - w.WriteHeader(http.StatusInternalServerError) + timer := model.GetTimerForUser(server.db, id, currentUser.Id) + if timer == nil { + w.WriteHeader(http.StatusNotFound) return } - timer.Token = newToken - res := updateTimerToken(server.db, timer.Id, newToken, currentUser.Id) + res := model.RegenerateTimerToken(server.db, timer.Id, currentUser.Id) if !res { w.WriteHeader(http.StatusInternalServerError) return @@ -325,29 +183,26 @@ func (server *MyServer) handleResetTimerToken(w http.ResponseWriter, r *http.Req view.TimerTokenForm(*timer).Render(context.Background(), w) } -func (server *MyServer) handleDeleteTimer(w http.ResponseWriter, r *http.Request) { +func (server *TimerServer) handleDeleteTimer(w http.ResponseWriter, r *http.Request) { user := server.findCurrentUser(w, r) if user == nil { w.WriteHeader(http.StatusUnauthorized) return } - success := deleteTimer(server.db, r.PathValue("timerId"), user.Id) - if !success { + var id model.UUID + if err := id.Scan(r.PathValue("timerId")); err != nil { w.WriteHeader(http.StatusNotFound) + return } -} -func parseNumber(s string) (int64, error) { - s = strings.TrimSpace(s) - if len(s) == 0 { - s = "0" + success := model.DeleteTimer(server.db, id, user.Id) + if !success { + w.WriteHeader(http.StatusNotFound) } - - return strconv.ParseInt(s, 10, 64) } -func (server *MyServer) handleCreateTimer(w http.ResponseWriter, r *http.Request) { +func (server *TimerServer) handleCreateTimer(w http.ResponseWriter, r *http.Request) { timerName := strings.TrimSpace(r.FormValue("timerName")) user := server.findCurrentUser(w, r) @@ -357,14 +212,14 @@ func (server *MyServer) handleCreateTimer(w http.ResponseWriter, r *http.Request return } - days, err := parseNumber(r.FormValue("days")) + days, err := utils.ParseNumber(r.FormValue("days")) if err != nil { w.WriteHeader(http.StatusBadRequest) view.TimerCreateForm(timerName, "Error parsing days").Render(context.Background(), w) return } - hours, err := parseNumber(r.FormValue("hours")) + hours, err := utils.ParseNumber(r.FormValue("hours")) if err != nil { w.WriteHeader(http.StatusBadRequest) view.TimerCreateForm(timerName, "Error parsing hours").Render(context.Background(), w) @@ -385,7 +240,7 @@ func (server *MyServer) handleCreateTimer(w http.ResponseWriter, r *http.Request return } - err = insertTimer(tx, timerName, int(((max(days, 0)*24)+max(hours, 0))*3600), user.Id) + err = model.InsertTimer(tx, timerName, int(((max(days, 0)*24)+max(hours, 0))*3600), user.Id) if err != nil { w.WriteHeader(http.StatusInternalServerError) view.TimerCreateForm(timerName, "Internal server error").Render(context.Background(), w) @@ -394,13 +249,13 @@ func (server *MyServer) handleCreateTimer(w http.ResponseWriter, r *http.Request tx.Commit() - timers := queryAllTimers(server.db, user.Id) + timers := model.GetTimersForUser(server.db, user.Id) view.TimersList(timers, user != nil).Render(context.Background(), w) } -func (server *MyServer) handlePostLogin(w http.ResponseWriter, r *http.Request) { +func (server *TimerServer) handlePostLogin(w http.ResponseWriter, r *http.Request) { if server.findCurrentUser(w, r) != nil { - HtmxRedirect(w, "/") + utils.HtmxRedirect(w, "/") return } @@ -421,30 +276,17 @@ func (server *MyServer) handlePostLogin(w http.ResponseWriter, r *http.Request) return } - sessionId, err := generateSessionId() - if err != nil { + if err := server.sessions.StartSession(user.Id, w); err == nil { + utils.HtmxRedirect(w, "/") + } else { w.WriteHeader(http.StatusInternalServerError) view.LoginFormError(nil, "Internal server error").Render(context.Background(), w) - return } - - cookie := http.Cookie{ - Name: SessionCookieName, - Value: sessionId, - HttpOnly: true, - Secure: true, - } - server.sessions[sessionId] = Session{UserId: user.Id} - http.SetCookie(w, &cookie) - HtmxRedirect(w, "/") } -func (server *MyServer) handlePostLogout(w http.ResponseWriter, r *http.Request) { - if cookie, err := r.Cookie(SessionCookieName); err == nil { - delete(server.sessions, cookie.Value) - removeCookie(SessionCookieName, w) - } - HtmxRedirect(w, "/") +func (server *TimerServer) handlePostLogout(w http.ResponseWriter, r *http.Request) { + server.sessions.EndSession(w, r) + utils.HtmxRedirect(w, "/") } func main() { @@ -460,7 +302,7 @@ func main() { log.Fatalln(err) } - myServer := MyServer{db: db, sessions: make(map[string]Session)} + myServer := TimerServer{db: db, sessions: MakeSessions()} fs := http.FileServer(http.Dir("static/")) http.Handle("GET /static/", http.StripPrefix("/static/", fs)) diff --git a/utils.go b/utils.go deleted file mode 100644 index 5aa2894..0000000 --- a/utils.go +++ /dev/null @@ -1,15 +0,0 @@ -package main - -import ( - "crypto/rand" - "encoding/base64" -) - -func GenerateRandomString(len int) (string, error) { - bin := make([]byte, len) - _, err := rand.Read(bin) - if err != nil { - return "", err - } - return base64.StdEncoding.EncodeToString(bin), nil -} @@ -1,4 +1,4 @@ -package main +package utils import ( "net/http" diff --git a/utils/utils.go b/utils/utils.go new file mode 100644 index 0000000..607236d --- /dev/null +++ b/utils/utils.go @@ -0,0 +1,58 @@ +package utils + +import ( + "crypto/rand" + "encoding/base64" + "errors" + "strconv" + "strings" + "time" +) + +func GenerateRandomString(len int) (string, error) { + bin := make([]byte, len) + _, err := rand.Read(bin) + if err != nil { + return "", err + } + return base64.StdEncoding.EncodeToString(bin), nil +} + +func ParseNumber(s string) (int64, error) { + s = strings.TrimSpace(s) + if len(s) == 0 { + s = "0" + } + + return strconv.ParseInt(s, 10, 64) +} + +func ParseDuration(value string) (time.Duration, error) { + const nullDuration = time.Duration(0) + if len(value) == 0 { + return nullDuration, errors.New("Empty duration string") + } + + var unit time.Duration + switch value[len(value)-1] { + case 's': + unit = time.Second + case 'm': + unit = time.Minute + case 'h': + unit = time.Hour + case 'd': + unit = time.Duration(24) * time.Hour + case 'w': + unit = time.Duration(24*7) * time.Hour + default: + return nullDuration, errors.New("Invalid duration format") + } + + amount, err := strconv.ParseInt(value[0:len(value)-1], 10, 64) + if err != nil || amount < 0 { + return nullDuration, errors.New("Invalid duration value") + } + + return time.Duration(amount) * unit, nil +} |