Protect timer operations with session

This commit is contained in:
2024-04-15 23:59:00 +02:00
parent 3d507d36ea
commit 9da246e91e
6 changed files with 86 additions and 54 deletions

View File

@ -5,6 +5,7 @@ type Timer struct {
Name string Name string
StartTime Time StartTime Time
EndTime Time EndTime Time
Owner UUID
} }
func (self Timer) IsFinished() bool { func (self Timer) IsFinished() bool {

105
timer.go
View File

@ -2,7 +2,9 @@ package main
import ( import (
"context" "context"
"crypto/rand"
"database/sql" "database/sql"
"encoding/base64"
"errors" "errors"
"log" "log"
"net/http" "net/http"
@ -12,20 +14,18 @@ import (
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"github.com/google/uuid"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"stevenlr.com/timer/model" "stevenlr.com/timer/model"
"stevenlr.com/timer/view" "stevenlr.com/timer/view"
) )
func insertTimer(tx *sql.Tx, name string, seconds int) error { func insertTimer(tx *sql.Tx, name string, seconds int, ownerId model.UUID) error {
now := model.MakeTimeNow() now := model.MakeTimeNow()
end := model.Time(time.Time(now).Add(time.Duration(seconds) * time.Second)) end := model.Time(time.Time(now).Add(time.Duration(seconds) * time.Second))
id := model.MakeUUID() id := model.MakeUUID()
_, err := tx.Exec(` _, err := tx.Exec(`
INSERT INTO Timer VALUES ($1, $2, $3, $4)`, id, name, now, end) INSERT INTO Timer VALUES ($1, $2, $3, $4, $5)`, id, name, now, end, ownerId)
return err return err
} }
@ -47,18 +47,16 @@ func initializeDatabase(db *sql.DB) error {
Name TEXT NOT NULL, Name TEXT NOT NULL,
StartTime TEXT NOT NULL, StartTime TEXT NOT NULL,
EndTime TEXT NOT NULL, EndTime TEXT NOT NULL,
Owner BLOB NOT NULL,
PRIMARY KEY (id) PRIMARY KEY (id)
)`) )`)
if err != nil { if err != nil {
return err return err
} }
err = insertTimer(tx, "My timer", 600) userUuidStr := "7015cee7-89a5-4057-b7c9-7e0128ad5086"
if err != nil { var userId model.UUID
return err err = userId.Scan(userUuidStr)
}
err = insertTimer(tx, "My timer2", 600)
if err != nil { if err != nil {
return err return err
} }
@ -75,13 +73,6 @@ func initializeDatabase(db *sql.DB) error {
return err return err
} }
userUuidStr := "7015cee7-89a5-4057-b7c9-7e0128ad5086"
var userId model.UUID
err = userId.Scan(userUuidStr)
if err != nil {
return err
}
userPasswordClear := "steven" userPasswordClear := "steven"
password, err := bcrypt.GenerateFromPassword([]byte(userUuidStr+userPasswordClear), bcrypt.MinCost) password, err := bcrypt.GenerateFromPassword([]byte(userUuidStr+userPasswordClear), bcrypt.MinCost)
@ -97,8 +88,8 @@ func initializeDatabase(db *sql.DB) error {
return tx.Commit() return tx.Commit()
} }
func queryAllTimers(db *sql.DB) []model.Timer { func queryAllTimers(db *sql.DB, owner model.UUID) []model.Timer {
rows, err := db.Query("SELECT Id, Name FROM Timer") rows, err := db.Query("SELECT Id, Name FROM Timer WHERE Owner=$1", owner)
if err != nil { if err != nil {
log.Fatalln(err) log.Fatalln(err)
} }
@ -138,29 +129,29 @@ func queryUserById(db *sql.DB, id model.UUID) *model.User {
return &user return &user
} }
func queryTimer(db *sql.DB, idStr string) *model.Timer { func queryTimer(db *sql.DB, idStr string, userId model.UUID) *model.Timer {
var id model.UUID var id model.UUID
if err := id.Scan(idStr); err != nil { if err := id.Scan(idStr); err != nil {
return nil return nil
} }
row := db.QueryRow("SELECT Id, Name, StartTime, EndTime FROM Timer WHERE Id=$1", id) row := db.QueryRow("SELECT Id, Name, StartTime, EndTime, Owner FROM Timer WHERE Id=$1 AND Owner=$2", id, userId)
var t model.Timer var t model.Timer
if err := row.Scan(&t.Id, &t.Name, &t.StartTime, &t.EndTime); err == nil { if err := row.Scan(&t.Id, &t.Name, &t.StartTime, &t.EndTime, &t.Owner); err == nil {
return &t return &t
} }
return nil return nil
} }
func deleteTimer(db *sql.DB, idStr string) bool { func deleteTimer(db *sql.DB, idStr string, userId model.UUID) bool {
var id model.UUID var id model.UUID
if err := id.Scan(idStr); err != nil { if err := id.Scan(idStr); err != nil {
return false return false
} }
res, err := db.Exec("DELETE FROM Timer WHERE Id=$1", id) res, err := db.Exec("DELETE FROM Timer WHERE Id=$1 AND Owner=$2", id, userId)
if err != nil { if err != nil {
return false return false
} }
@ -169,8 +160,8 @@ func deleteTimer(db *sql.DB, idStr string) bool {
return err == nil && affected == 1 return err == nil && affected == 1
} }
func updateTimerEndTime(db *sql.DB, id model.UUID, endTime model.Time) bool { func updateTimerEndTime(db *sql.DB, id model.UUID, endTime model.Time, userId model.UUID) bool {
res, err := db.Exec("UPDATE Timer SET EndTime=$1 WHERE Id=$2", endTime, id) res, err := db.Exec("UPDATE Timer SET EndTime=$1 WHERE Id=$2 AND Owner=$3", endTime, id, userId)
if err != nil { if err != nil {
return false return false
} }
@ -227,8 +218,11 @@ func (server *MyServer) handleNotFound(w http.ResponseWriter, _ *http.Request) {
func (server *MyServer) handleMain(w http.ResponseWriter, r *http.Request) { func (server *MyServer) handleMain(w http.ResponseWriter, r *http.Request) {
currentUser := server.findCurrentUser(w, r) currentUser := server.findCurrentUser(w, r)
if r.URL.Path == "/" { if r.URL.Path == "/" {
timers := queryAllTimers(server.db) timers := make([]model.Timer, 0)
view.Main(view.TimersList(timers), currentUser).Render(context.Background(), w) if currentUser != nil {
timers = queryAllTimers(server.db, currentUser.Id)
}
view.Main(view.TimersList(timers, currentUser != nil), currentUser).Render(context.Background(), w)
} else { } else {
server.handleNotFound(w, r) server.handleNotFound(w, r)
} }
@ -236,8 +230,13 @@ func (server *MyServer) handleMain(w http.ResponseWriter, r *http.Request) {
func (server *MyServer) handleTimer(w http.ResponseWriter, r *http.Request) { func (server *MyServer) handleTimer(w http.ResponseWriter, r *http.Request) {
currentUser := server.findCurrentUser(w, r) currentUser := server.findCurrentUser(w, r)
timer := queryTimer(server.db, r.PathValue("timerId")) if currentUser == nil {
if timer != nil { server.handleNotFound(w, r)
return
}
timer := queryTimer(server.db, r.PathValue("timerId"), currentUser.Id)
if timer != nil && timer.Owner == currentUser.Id {
view.Main(view.TimerView(*timer), currentUser).Render(context.Background(), w) view.Main(view.TimerView(*timer), currentUser).Render(context.Background(), w)
} else { } else {
server.handleNotFound(w, r) server.handleNotFound(w, r)
@ -275,7 +274,13 @@ func parseDuration(value string) (time.Duration, error) {
} }
func (server *MyServer) handleTimerAddTime(w http.ResponseWriter, r *http.Request) { func (server *MyServer) handleTimerAddTime(w http.ResponseWriter, r *http.Request) {
timer := queryTimer(server.db, r.PathValue("timerId")) currentUser := server.findCurrentUser(w, r)
if currentUser == nil {
w.WriteHeader(http.StatusUnauthorized)
return
}
timer := queryTimer(server.db, r.PathValue("timerId"), currentUser.Id)
if timer == nil { if timer == nil {
server.handleNotFound(w, r) server.handleNotFound(w, r)
return return
@ -294,7 +299,7 @@ func (server *MyServer) handleTimerAddTime(w http.ResponseWriter, r *http.Reques
} }
timer.EndTime.Add(duration) timer.EndTime.Add(duration)
res := updateTimerEndTime(server.db, timer.Id, timer.EndTime) res := updateTimerEndTime(server.db, timer.Id, timer.EndTime, currentUser.Id)
if !res { if !res {
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
return return
@ -304,7 +309,13 @@ func (server *MyServer) handleTimerAddTime(w http.ResponseWriter, r *http.Reques
} }
func (server *MyServer) handleDeleteTimer(w http.ResponseWriter, r *http.Request) { func (server *MyServer) handleDeleteTimer(w http.ResponseWriter, r *http.Request) {
success := deleteTimer(server.db, r.PathValue("timerId")) user := server.findCurrentUser(w, r)
if user == nil {
w.WriteHeader(http.StatusUnauthorized)
return
}
success := deleteTimer(server.db, r.PathValue("timerId"), user.Id)
if !success { if !success {
w.WriteHeader(http.StatusNotFound) w.WriteHeader(http.StatusNotFound)
} }
@ -313,6 +324,13 @@ func (server *MyServer) handleDeleteTimer(w http.ResponseWriter, r *http.Request
func (server *MyServer) handlePutTimer(w http.ResponseWriter, r *http.Request) { func (server *MyServer) handlePutTimer(w http.ResponseWriter, r *http.Request) {
timerName := strings.TrimSpace(r.FormValue("timerName")) timerName := strings.TrimSpace(r.FormValue("timerName"))
user := server.findCurrentUser(w, r)
if user == nil {
w.WriteHeader(http.StatusBadRequest)
view.TimerCreateForm(timerName, "You are not signed in").Render(context.Background(), w)
return
}
days, err := strconv.ParseInt(strings.TrimSpace(r.FormValue("days")), 10, 32) days, err := strconv.ParseInt(strings.TrimSpace(r.FormValue("days")), 10, 32)
if err != nil { if err != nil {
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
@ -341,7 +359,7 @@ func (server *MyServer) handlePutTimer(w http.ResponseWriter, r *http.Request) {
return return
} }
err = insertTimer(tx, timerName, int(((max(days, 0)*24)+max(hours, 0))*3600)) err = insertTimer(tx, timerName, int(((max(days, 0)*24)+max(hours, 0))*3600), user.Id)
if err != nil { if err != nil {
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
view.TimerCreateForm(timerName, "Internal server error").Render(context.Background(), w) view.TimerCreateForm(timerName, "Internal server error").Render(context.Background(), w)
@ -350,8 +368,17 @@ func (server *MyServer) handlePutTimer(w http.ResponseWriter, r *http.Request) {
tx.Commit() tx.Commit()
timers := queryAllTimers(server.db) timers := queryAllTimers(server.db, user.Id)
view.TimersList(timers).Render(context.Background(), w) view.TimersList(timers, user != nil).Render(context.Background(), w)
}
func generateSessionId() (string, error) {
bin := make([]byte, 64)
_, err := rand.Read(bin)
if err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(bin), nil
} }
func (server *MyServer) handlePostLogin(w http.ResponseWriter, r *http.Request) { func (server *MyServer) handlePostLogin(w http.ResponseWriter, r *http.Request) {
@ -377,7 +404,7 @@ func (server *MyServer) handlePostLogin(w http.ResponseWriter, r *http.Request)
return return
} }
sessionId, err := uuid.NewRandom() sessionId, err := generateSessionId()
if err != nil { if err != nil {
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
view.LoginFormError(nil, "Internal server error").Render(context.Background(), w) view.LoginFormError(nil, "Internal server error").Render(context.Background(), w)
@ -386,11 +413,11 @@ func (server *MyServer) handlePostLogin(w http.ResponseWriter, r *http.Request)
cookie := http.Cookie{ cookie := http.Cookie{
Name: SessionCookieName, Name: SessionCookieName,
Value: sessionId.String(), Value: sessionId,
HttpOnly: true, HttpOnly: true,
Secure: true, Secure: true,
} }
server.sessions[sessionId.String()] = Session{UserId: user.Id} server.sessions[sessionId] = Session{UserId: user.Id}
http.SetCookie(w, &cookie) http.SetCookie(w, &cookie)
w.Header().Add("HX-Redirect", "/") w.Header().Add("HX-Redirect", "/")
} }

View File

@ -7,7 +7,7 @@ import (
templ TimerView(timer model.Timer) { templ TimerView(timer model.Timer) {
<div class="timer"> <div class="timer">
<h1>This is timer { timer.Name } </h1> <h1>Timer "{ timer.Name }"</h1>
<p><a href="/">Back to list</a></p> <p><a href="/">Back to list</a></p>
<p>Start time: <local-date>{ timer.StartTime.AsUTCString() }</local-date></p> <p>Start time: <local-date>{ timer.StartTime.AsUTCString() }</local-date></p>
<p>End time: <local-date>{ timer.EndTime.AsUTCString() }</local-date></p> <p>End time: <local-date>{ timer.EndTime.AsUTCString() }</local-date></p>

View File

@ -28,20 +28,20 @@ func TimerView(timer model.Timer) templ.Component {
templ_7745c5c3_Var1 = templ.NopComponent templ_7745c5c3_Var1 = templ.NopComponent
} }
ctx = templ.ClearChildren(ctx) ctx = templ.ClearChildren(ctx)
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString("<div class=\"timer\"><h1>This is timer ") _, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString("<div class=\"timer\"><h1>Timer \"")
if templ_7745c5c3_Err != nil { if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err return templ_7745c5c3_Err
} }
var templ_7745c5c3_Var2 string var templ_7745c5c3_Var2 string
templ_7745c5c3_Var2, templ_7745c5c3_Err = templ.JoinStringErrs(timer.Name) templ_7745c5c3_Var2, templ_7745c5c3_Err = templ.JoinStringErrs(timer.Name)
if templ_7745c5c3_Err != nil { if templ_7745c5c3_Err != nil {
return templ.Error{Err: templ_7745c5c3_Err, FileName: `view\timer.templ`, Line: 10, Col: 32} return templ.Error{Err: templ_7745c5c3_Err, FileName: `view\timer.templ`, Line: 10, Col: 25}
} }
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var2)) _, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var2))
if templ_7745c5c3_Err != nil { if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err return templ_7745c5c3_Err
} }
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString("</h1><p><a href=\"/\">Back to list</a></p><p>Start time: <local-date>") _, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString("\"</h1><p><a href=\"/\">Back to list</a></p><p>Start time: <local-date>")
if templ_7745c5c3_Err != nil { if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err return templ_7745c5c3_Err
} }

View File

@ -35,14 +35,16 @@ templ TimerCreateForm(timerName string, err string) {
</form> </form>
} }
templ TimersList(timers []model.Timer) { templ TimersList(timers []model.Timer, isSignedIn bool) {
<div class="timers-list"> <div class="timers-list">
<h1>Timers</h1> <h1>Timers</h1>
for _, t := range timers { for _, t := range timers {
@timer(t) @timer(t)
} }
<h4>Create timer</h4> if isSignedIn {
@TimerCreateForm("", "") <h4>Create timer</h4>
@TimerCreateForm("", "")
}
</div> </div>
} }

View File

@ -134,7 +134,7 @@ func TimerCreateForm(timerName string, err string) templ.Component {
}) })
} }
func TimersList(timers []model.Timer) templ.Component { func TimersList(timers []model.Timer, isSignedIn bool) templ.Component {
return templ.ComponentFunc(func(ctx context.Context, templ_7745c5c3_W io.Writer) (templ_7745c5c3_Err error) { return templ.ComponentFunc(func(ctx context.Context, templ_7745c5c3_W io.Writer) (templ_7745c5c3_Err error) {
templ_7745c5c3_Buffer, templ_7745c5c3_IsBuffer := templ_7745c5c3_W.(*bytes.Buffer) templ_7745c5c3_Buffer, templ_7745c5c3_IsBuffer := templ_7745c5c3_W.(*bytes.Buffer)
if !templ_7745c5c3_IsBuffer { if !templ_7745c5c3_IsBuffer {
@ -157,13 +157,15 @@ func TimersList(timers []model.Timer) templ.Component {
return templ_7745c5c3_Err return templ_7745c5c3_Err
} }
} }
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString("<h4>Create timer</h4>") if isSignedIn {
if templ_7745c5c3_Err != nil { _, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString("<h4>Create timer</h4>")
return templ_7745c5c3_Err if templ_7745c5c3_Err != nil {
} return templ_7745c5c3_Err
templ_7745c5c3_Err = TimerCreateForm("", "").Render(ctx, templ_7745c5c3_Buffer) }
if templ_7745c5c3_Err != nil { templ_7745c5c3_Err = TimerCreateForm("", "").Render(ctx, templ_7745c5c3_Buffer)
return templ_7745c5c3_Err if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
} }
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString("</div>") _, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString("</div>")
if templ_7745c5c3_Err != nil { if templ_7745c5c3_Err != nil {