summaryrefslogtreecommitdiff
path: root/timer.go
diff options
context:
space:
mode:
Diffstat (limited to 'timer.go')
-rw-r--r--timer.go105
1 files changed, 66 insertions, 39 deletions
diff --git a/timer.go b/timer.go
index 72d4094..ef65b8b 100644
--- a/timer.go
+++ b/timer.go
@@ -2,7 +2,9 @@ package main
import (
"context"
+ "crypto/rand"
"database/sql"
+ "encoding/base64"
"errors"
"log"
"net/http"
@@ -12,20 +14,18 @@ import (
"golang.org/x/crypto/bcrypt"
- "github.com/google/uuid"
-
_ "github.com/mattn/go-sqlite3"
"stevenlr.com/timer/model"
"stevenlr.com/timer/view"
)
-func insertTimer(tx *sql.Tx, name string, seconds int) error {
+func insertTimer(tx *sql.Tx, name string, seconds int, ownerId model.UUID) error {
now := model.MakeTimeNow()
end := model.Time(time.Time(now).Add(time.Duration(seconds) * time.Second))
id := model.MakeUUID()
_, err := tx.Exec(`
- INSERT INTO Timer VALUES ($1, $2, $3, $4)`, id, name, now, end)
+ INSERT INTO Timer VALUES ($1, $2, $3, $4, $5)`, id, name, now, end, ownerId)
return err
}
@@ -47,18 +47,16 @@ func initializeDatabase(db *sql.DB) error {
Name TEXT NOT NULL,
StartTime TEXT NOT NULL,
EndTime TEXT NOT NULL,
+ Owner BLOB NOT NULL,
PRIMARY KEY (id)
)`)
if err != nil {
return err
}
- err = insertTimer(tx, "My timer", 600)
- if err != nil {
- return err
- }
-
- err = insertTimer(tx, "My timer2", 600)
+ userUuidStr := "7015cee7-89a5-4057-b7c9-7e0128ad5086"
+ var userId model.UUID
+ err = userId.Scan(userUuidStr)
if err != nil {
return err
}
@@ -75,13 +73,6 @@ func initializeDatabase(db *sql.DB) error {
return err
}
- userUuidStr := "7015cee7-89a5-4057-b7c9-7e0128ad5086"
- var userId model.UUID
- err = userId.Scan(userUuidStr)
- if err != nil {
- return err
- }
-
userPasswordClear := "steven"
password, err := bcrypt.GenerateFromPassword([]byte(userUuidStr+userPasswordClear), bcrypt.MinCost)
@@ -97,8 +88,8 @@ func initializeDatabase(db *sql.DB) error {
return tx.Commit()
}
-func queryAllTimers(db *sql.DB) []model.Timer {
- rows, err := db.Query("SELECT Id, Name FROM Timer")
+func queryAllTimers(db *sql.DB, owner model.UUID) []model.Timer {
+ rows, err := db.Query("SELECT Id, Name FROM Timer WHERE Owner=$1", owner)
if err != nil {
log.Fatalln(err)
}
@@ -138,29 +129,29 @@ func queryUserById(db *sql.DB, id model.UUID) *model.User {
return &user
}
-func queryTimer(db *sql.DB, idStr string) *model.Timer {
+func queryTimer(db *sql.DB, idStr string, userId model.UUID) *model.Timer {
var id model.UUID
if err := id.Scan(idStr); err != nil {
return nil
}
- row := db.QueryRow("SELECT Id, Name, StartTime, EndTime FROM Timer WHERE Id=$1", id)
+ row := db.QueryRow("SELECT Id, Name, StartTime, EndTime, Owner FROM Timer WHERE Id=$1 AND Owner=$2", id, userId)
var t model.Timer
- if err := row.Scan(&t.Id, &t.Name, &t.StartTime, &t.EndTime); err == nil {
+ if err := row.Scan(&t.Id, &t.Name, &t.StartTime, &t.EndTime, &t.Owner); err == nil {
return &t
}
return nil
}
-func deleteTimer(db *sql.DB, idStr string) bool {
+func deleteTimer(db *sql.DB, idStr string, userId model.UUID) bool {
var id model.UUID
if err := id.Scan(idStr); err != nil {
return false
}
- res, err := db.Exec("DELETE FROM Timer WHERE Id=$1", id)
+ res, err := db.Exec("DELETE FROM Timer WHERE Id=$1 AND Owner=$2", id, userId)
if err != nil {
return false
}
@@ -169,8 +160,8 @@ func deleteTimer(db *sql.DB, idStr string) bool {
return err == nil && affected == 1
}
-func updateTimerEndTime(db *sql.DB, id model.UUID, endTime model.Time) bool {
- res, err := db.Exec("UPDATE Timer SET EndTime=$1 WHERE Id=$2", endTime, id)
+func updateTimerEndTime(db *sql.DB, id model.UUID, endTime model.Time, userId model.UUID) bool {
+ res, err := db.Exec("UPDATE Timer SET EndTime=$1 WHERE Id=$2 AND Owner=$3", endTime, id, userId)
if err != nil {
return false
}
@@ -227,8 +218,11 @@ func (server *MyServer) handleNotFound(w http.ResponseWriter, _ *http.Request) {
func (server *MyServer) handleMain(w http.ResponseWriter, r *http.Request) {
currentUser := server.findCurrentUser(w, r)
if r.URL.Path == "/" {
- timers := queryAllTimers(server.db)
- view.Main(view.TimersList(timers), currentUser).Render(context.Background(), w)
+ timers := make([]model.Timer, 0)
+ if currentUser != nil {
+ timers = queryAllTimers(server.db, currentUser.Id)
+ }
+ view.Main(view.TimersList(timers, currentUser != nil), currentUser).Render(context.Background(), w)
} else {
server.handleNotFound(w, r)
}
@@ -236,8 +230,13 @@ func (server *MyServer) handleMain(w http.ResponseWriter, r *http.Request) {
func (server *MyServer) handleTimer(w http.ResponseWriter, r *http.Request) {
currentUser := server.findCurrentUser(w, r)
- timer := queryTimer(server.db, r.PathValue("timerId"))
- if timer != nil {
+ if currentUser == nil {
+ server.handleNotFound(w, r)
+ return
+ }
+
+ timer := queryTimer(server.db, r.PathValue("timerId"), currentUser.Id)
+ if timer != nil && timer.Owner == currentUser.Id {
view.Main(view.TimerView(*timer), currentUser).Render(context.Background(), w)
} else {
server.handleNotFound(w, r)
@@ -275,7 +274,13 @@ func parseDuration(value string) (time.Duration, error) {
}
func (server *MyServer) handleTimerAddTime(w http.ResponseWriter, r *http.Request) {
- timer := queryTimer(server.db, r.PathValue("timerId"))
+ currentUser := server.findCurrentUser(w, r)
+ if currentUser == nil {
+ w.WriteHeader(http.StatusUnauthorized)
+ return
+ }
+
+ timer := queryTimer(server.db, r.PathValue("timerId"), currentUser.Id)
if timer == nil {
server.handleNotFound(w, r)
return
@@ -294,7 +299,7 @@ func (server *MyServer) handleTimerAddTime(w http.ResponseWriter, r *http.Reques
}
timer.EndTime.Add(duration)
- res := updateTimerEndTime(server.db, timer.Id, timer.EndTime)
+ res := updateTimerEndTime(server.db, timer.Id, timer.EndTime, currentUser.Id)
if !res {
w.WriteHeader(http.StatusBadRequest)
return
@@ -304,7 +309,13 @@ func (server *MyServer) handleTimerAddTime(w http.ResponseWriter, r *http.Reques
}
func (server *MyServer) handleDeleteTimer(w http.ResponseWriter, r *http.Request) {
- success := deleteTimer(server.db, r.PathValue("timerId"))
+ user := server.findCurrentUser(w, r)
+ if user == nil {
+ w.WriteHeader(http.StatusUnauthorized)
+ return
+ }
+
+ success := deleteTimer(server.db, r.PathValue("timerId"), user.Id)
if !success {
w.WriteHeader(http.StatusNotFound)
}
@@ -313,6 +324,13 @@ func (server *MyServer) handleDeleteTimer(w http.ResponseWriter, r *http.Request
func (server *MyServer) handlePutTimer(w http.ResponseWriter, r *http.Request) {
timerName := strings.TrimSpace(r.FormValue("timerName"))
+ user := server.findCurrentUser(w, r)
+ if user == nil {
+ w.WriteHeader(http.StatusBadRequest)
+ view.TimerCreateForm(timerName, "You are not signed in").Render(context.Background(), w)
+ return
+ }
+
days, err := strconv.ParseInt(strings.TrimSpace(r.FormValue("days")), 10, 32)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
@@ -341,7 +359,7 @@ func (server *MyServer) handlePutTimer(w http.ResponseWriter, r *http.Request) {
return
}
- err = insertTimer(tx, timerName, int(((max(days, 0)*24)+max(hours, 0))*3600))
+ err = insertTimer(tx, timerName, int(((max(days, 0)*24)+max(hours, 0))*3600), user.Id)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
view.TimerCreateForm(timerName, "Internal server error").Render(context.Background(), w)
@@ -350,8 +368,17 @@ func (server *MyServer) handlePutTimer(w http.ResponseWriter, r *http.Request) {
tx.Commit()
- timers := queryAllTimers(server.db)
- view.TimersList(timers).Render(context.Background(), w)
+ timers := queryAllTimers(server.db, user.Id)
+ view.TimersList(timers, user != nil).Render(context.Background(), w)
+}
+
+func generateSessionId() (string, error) {
+ bin := make([]byte, 64)
+ _, err := rand.Read(bin)
+ if err != nil {
+ return "", err
+ }
+ return base64.StdEncoding.EncodeToString(bin), nil
}
func (server *MyServer) handlePostLogin(w http.ResponseWriter, r *http.Request) {
@@ -377,7 +404,7 @@ func (server *MyServer) handlePostLogin(w http.ResponseWriter, r *http.Request)
return
}
- sessionId, err := uuid.NewRandom()
+ sessionId, err := generateSessionId()
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
view.LoginFormError(nil, "Internal server error").Render(context.Background(), w)
@@ -386,11 +413,11 @@ func (server *MyServer) handlePostLogin(w http.ResponseWriter, r *http.Request)
cookie := http.Cookie{
Name: SessionCookieName,
- Value: sessionId.String(),
+ Value: sessionId,
HttpOnly: true,
Secure: true,
}
- server.sessions[sessionId.String()] = Session{UserId: user.Id}
+ server.sessions[sessionId] = Session{UserId: user.Id}
http.SetCookie(w, &cookie)
w.Header().Add("HX-Redirect", "/")
}