Protect timer operations with session

This commit is contained in:
2024-04-15 23:59:00 +02:00
parent 3d507d36ea
commit 9da246e91e
6 changed files with 86 additions and 54 deletions

View File

@ -5,6 +5,7 @@ type Timer struct {
Name string
StartTime Time
EndTime Time
Owner UUID
}
func (self Timer) IsFinished() bool {

105
timer.go
View File

@ -2,7 +2,9 @@ package main
import (
"context"
"crypto/rand"
"database/sql"
"encoding/base64"
"errors"
"log"
"net/http"
@ -12,20 +14,18 @@ import (
"golang.org/x/crypto/bcrypt"
"github.com/google/uuid"
_ "github.com/mattn/go-sqlite3"
"stevenlr.com/timer/model"
"stevenlr.com/timer/view"
)
func insertTimer(tx *sql.Tx, name string, seconds int) error {
func insertTimer(tx *sql.Tx, name string, seconds int, ownerId model.UUID) error {
now := model.MakeTimeNow()
end := model.Time(time.Time(now).Add(time.Duration(seconds) * time.Second))
id := model.MakeUUID()
_, err := tx.Exec(`
INSERT INTO Timer VALUES ($1, $2, $3, $4)`, id, name, now, end)
INSERT INTO Timer VALUES ($1, $2, $3, $4, $5)`, id, name, now, end, ownerId)
return err
}
@ -47,18 +47,16 @@ func initializeDatabase(db *sql.DB) error {
Name TEXT NOT NULL,
StartTime TEXT NOT NULL,
EndTime TEXT NOT NULL,
Owner BLOB NOT NULL,
PRIMARY KEY (id)
)`)
if err != nil {
return err
}
err = insertTimer(tx, "My timer", 600)
if err != nil {
return err
}
err = insertTimer(tx, "My timer2", 600)
userUuidStr := "7015cee7-89a5-4057-b7c9-7e0128ad5086"
var userId model.UUID
err = userId.Scan(userUuidStr)
if err != nil {
return err
}
@ -75,13 +73,6 @@ func initializeDatabase(db *sql.DB) error {
return err
}
userUuidStr := "7015cee7-89a5-4057-b7c9-7e0128ad5086"
var userId model.UUID
err = userId.Scan(userUuidStr)
if err != nil {
return err
}
userPasswordClear := "steven"
password, err := bcrypt.GenerateFromPassword([]byte(userUuidStr+userPasswordClear), bcrypt.MinCost)
@ -97,8 +88,8 @@ func initializeDatabase(db *sql.DB) error {
return tx.Commit()
}
func queryAllTimers(db *sql.DB) []model.Timer {
rows, err := db.Query("SELECT Id, Name FROM Timer")
func queryAllTimers(db *sql.DB, owner model.UUID) []model.Timer {
rows, err := db.Query("SELECT Id, Name FROM Timer WHERE Owner=$1", owner)
if err != nil {
log.Fatalln(err)
}
@ -138,29 +129,29 @@ func queryUserById(db *sql.DB, id model.UUID) *model.User {
return &user
}
func queryTimer(db *sql.DB, idStr string) *model.Timer {
func queryTimer(db *sql.DB, idStr string, userId model.UUID) *model.Timer {
var id model.UUID
if err := id.Scan(idStr); err != nil {
return nil
}
row := db.QueryRow("SELECT Id, Name, StartTime, EndTime FROM Timer WHERE Id=$1", id)
row := db.QueryRow("SELECT Id, Name, StartTime, EndTime, Owner FROM Timer WHERE Id=$1 AND Owner=$2", id, userId)
var t model.Timer
if err := row.Scan(&t.Id, &t.Name, &t.StartTime, &t.EndTime); err == nil {
if err := row.Scan(&t.Id, &t.Name, &t.StartTime, &t.EndTime, &t.Owner); err == nil {
return &t
}
return nil
}
func deleteTimer(db *sql.DB, idStr string) bool {
func deleteTimer(db *sql.DB, idStr string, userId model.UUID) bool {
var id model.UUID
if err := id.Scan(idStr); err != nil {
return false
}
res, err := db.Exec("DELETE FROM Timer WHERE Id=$1", id)
res, err := db.Exec("DELETE FROM Timer WHERE Id=$1 AND Owner=$2", id, userId)
if err != nil {
return false
}
@ -169,8 +160,8 @@ func deleteTimer(db *sql.DB, idStr string) bool {
return err == nil && affected == 1
}
func updateTimerEndTime(db *sql.DB, id model.UUID, endTime model.Time) bool {
res, err := db.Exec("UPDATE Timer SET EndTime=$1 WHERE Id=$2", endTime, id)
func updateTimerEndTime(db *sql.DB, id model.UUID, endTime model.Time, userId model.UUID) bool {
res, err := db.Exec("UPDATE Timer SET EndTime=$1 WHERE Id=$2 AND Owner=$3", endTime, id, userId)
if err != nil {
return false
}
@ -227,8 +218,11 @@ func (server *MyServer) handleNotFound(w http.ResponseWriter, _ *http.Request) {
func (server *MyServer) handleMain(w http.ResponseWriter, r *http.Request) {
currentUser := server.findCurrentUser(w, r)
if r.URL.Path == "/" {
timers := queryAllTimers(server.db)
view.Main(view.TimersList(timers), currentUser).Render(context.Background(), w)
timers := make([]model.Timer, 0)
if currentUser != nil {
timers = queryAllTimers(server.db, currentUser.Id)
}
view.Main(view.TimersList(timers, currentUser != nil), currentUser).Render(context.Background(), w)
} else {
server.handleNotFound(w, r)
}
@ -236,8 +230,13 @@ func (server *MyServer) handleMain(w http.ResponseWriter, r *http.Request) {
func (server *MyServer) handleTimer(w http.ResponseWriter, r *http.Request) {
currentUser := server.findCurrentUser(w, r)
timer := queryTimer(server.db, r.PathValue("timerId"))
if timer != nil {
if currentUser == nil {
server.handleNotFound(w, r)
return
}
timer := queryTimer(server.db, r.PathValue("timerId"), currentUser.Id)
if timer != nil && timer.Owner == currentUser.Id {
view.Main(view.TimerView(*timer), currentUser).Render(context.Background(), w)
} else {
server.handleNotFound(w, r)
@ -275,7 +274,13 @@ func parseDuration(value string) (time.Duration, error) {
}
func (server *MyServer) handleTimerAddTime(w http.ResponseWriter, r *http.Request) {
timer := queryTimer(server.db, r.PathValue("timerId"))
currentUser := server.findCurrentUser(w, r)
if currentUser == nil {
w.WriteHeader(http.StatusUnauthorized)
return
}
timer := queryTimer(server.db, r.PathValue("timerId"), currentUser.Id)
if timer == nil {
server.handleNotFound(w, r)
return
@ -294,7 +299,7 @@ func (server *MyServer) handleTimerAddTime(w http.ResponseWriter, r *http.Reques
}
timer.EndTime.Add(duration)
res := updateTimerEndTime(server.db, timer.Id, timer.EndTime)
res := updateTimerEndTime(server.db, timer.Id, timer.EndTime, currentUser.Id)
if !res {
w.WriteHeader(http.StatusBadRequest)
return
@ -304,7 +309,13 @@ func (server *MyServer) handleTimerAddTime(w http.ResponseWriter, r *http.Reques
}
func (server *MyServer) handleDeleteTimer(w http.ResponseWriter, r *http.Request) {
success := deleteTimer(server.db, r.PathValue("timerId"))
user := server.findCurrentUser(w, r)
if user == nil {
w.WriteHeader(http.StatusUnauthorized)
return
}
success := deleteTimer(server.db, r.PathValue("timerId"), user.Id)
if !success {
w.WriteHeader(http.StatusNotFound)
}
@ -313,6 +324,13 @@ func (server *MyServer) handleDeleteTimer(w http.ResponseWriter, r *http.Request
func (server *MyServer) handlePutTimer(w http.ResponseWriter, r *http.Request) {
timerName := strings.TrimSpace(r.FormValue("timerName"))
user := server.findCurrentUser(w, r)
if user == nil {
w.WriteHeader(http.StatusBadRequest)
view.TimerCreateForm(timerName, "You are not signed in").Render(context.Background(), w)
return
}
days, err := strconv.ParseInt(strings.TrimSpace(r.FormValue("days")), 10, 32)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
@ -341,7 +359,7 @@ func (server *MyServer) handlePutTimer(w http.ResponseWriter, r *http.Request) {
return
}
err = insertTimer(tx, timerName, int(((max(days, 0)*24)+max(hours, 0))*3600))
err = insertTimer(tx, timerName, int(((max(days, 0)*24)+max(hours, 0))*3600), user.Id)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
view.TimerCreateForm(timerName, "Internal server error").Render(context.Background(), w)
@ -350,8 +368,17 @@ func (server *MyServer) handlePutTimer(w http.ResponseWriter, r *http.Request) {
tx.Commit()
timers := queryAllTimers(server.db)
view.TimersList(timers).Render(context.Background(), w)
timers := queryAllTimers(server.db, user.Id)
view.TimersList(timers, user != nil).Render(context.Background(), w)
}
func generateSessionId() (string, error) {
bin := make([]byte, 64)
_, err := rand.Read(bin)
if err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(bin), nil
}
func (server *MyServer) handlePostLogin(w http.ResponseWriter, r *http.Request) {
@ -377,7 +404,7 @@ func (server *MyServer) handlePostLogin(w http.ResponseWriter, r *http.Request)
return
}
sessionId, err := uuid.NewRandom()
sessionId, err := generateSessionId()
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
view.LoginFormError(nil, "Internal server error").Render(context.Background(), w)
@ -386,11 +413,11 @@ func (server *MyServer) handlePostLogin(w http.ResponseWriter, r *http.Request)
cookie := http.Cookie{
Name: SessionCookieName,
Value: sessionId.String(),
Value: sessionId,
HttpOnly: true,
Secure: true,
}
server.sessions[sessionId.String()] = Session{UserId: user.Id}
server.sessions[sessionId] = Session{UserId: user.Id}
http.SetCookie(w, &cookie)
w.Header().Add("HX-Redirect", "/")
}

View File

@ -7,7 +7,7 @@ import (
templ TimerView(timer model.Timer) {
<div class="timer">
<h1>This is timer { timer.Name } </h1>
<h1>Timer "{ timer.Name }"</h1>
<p><a href="/">Back to list</a></p>
<p>Start time: <local-date>{ timer.StartTime.AsUTCString() }</local-date></p>
<p>End time: <local-date>{ timer.EndTime.AsUTCString() }</local-date></p>

View File

@ -28,20 +28,20 @@ func TimerView(timer model.Timer) templ.Component {
templ_7745c5c3_Var1 = templ.NopComponent
}
ctx = templ.ClearChildren(ctx)
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString("<div class=\"timer\"><h1>This is timer ")
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString("<div class=\"timer\"><h1>Timer \"")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
var templ_7745c5c3_Var2 string
templ_7745c5c3_Var2, templ_7745c5c3_Err = templ.JoinStringErrs(timer.Name)
if templ_7745c5c3_Err != nil {
return templ.Error{Err: templ_7745c5c3_Err, FileName: `view\timer.templ`, Line: 10, Col: 32}
return templ.Error{Err: templ_7745c5c3_Err, FileName: `view\timer.templ`, Line: 10, Col: 25}
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var2))
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString("</h1><p><a href=\"/\">Back to list</a></p><p>Start time: <local-date>")
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString("\"</h1><p><a href=\"/\">Back to list</a></p><p>Start time: <local-date>")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}

View File

@ -35,14 +35,16 @@ templ TimerCreateForm(timerName string, err string) {
</form>
}
templ TimersList(timers []model.Timer) {
templ TimersList(timers []model.Timer, isSignedIn bool) {
<div class="timers-list">
<h1>Timers</h1>
for _, t := range timers {
@timer(t)
}
<h4>Create timer</h4>
@TimerCreateForm("", "")
if isSignedIn {
<h4>Create timer</h4>
@TimerCreateForm("", "")
}
</div>
}

View File

@ -134,7 +134,7 @@ func TimerCreateForm(timerName string, err string) templ.Component {
})
}
func TimersList(timers []model.Timer) templ.Component {
func TimersList(timers []model.Timer, isSignedIn bool) templ.Component {
return templ.ComponentFunc(func(ctx context.Context, templ_7745c5c3_W io.Writer) (templ_7745c5c3_Err error) {
templ_7745c5c3_Buffer, templ_7745c5c3_IsBuffer := templ_7745c5c3_W.(*bytes.Buffer)
if !templ_7745c5c3_IsBuffer {
@ -157,13 +157,15 @@ func TimersList(timers []model.Timer) templ.Component {
return templ_7745c5c3_Err
}
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString("<h4>Create timer</h4>")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
templ_7745c5c3_Err = TimerCreateForm("", "").Render(ctx, templ_7745c5c3_Buffer)
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
if isSignedIn {
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString("<h4>Create timer</h4>")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
templ_7745c5c3_Err = TimerCreateForm("", "").Render(ctx, templ_7745c5c3_Buffer)
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString("</div>")
if templ_7745c5c3_Err != nil {