summaryrefslogtreecommitdiff
path: root/pkg/middleware
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/middleware')
-rw-r--r--pkg/middleware/context.go21
-rw-r--r--pkg/middleware/session/store.go232
-rw-r--r--pkg/middleware/state.go58
3 files changed, 311 insertions, 0 deletions
diff --git a/pkg/middleware/context.go b/pkg/middleware/context.go
new file mode 100644
index 0000000..f9c4425
--- /dev/null
+++ b/pkg/middleware/context.go
@@ -0,0 +1,21 @@
+package middleware
+
+import (
+ "context"
+
+ "github.com/labstack/echo/v4"
+)
+
+type ContextKey string
+
+func Context(key ContextKey, value any) echo.MiddlewareFunc {
+ return func(next echo.HandlerFunc) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ ctx := context.WithValue(c.Request().Context(), key, value)
+ r := c.Request().WithContext(ctx)
+ c.SetRequest(r)
+
+ return next(c)
+ }
+ }
+}
diff --git a/pkg/middleware/session/store.go b/pkg/middleware/session/store.go
new file mode 100644
index 0000000..78172d7
--- /dev/null
+++ b/pkg/middleware/session/store.go
@@ -0,0 +1,232 @@
+package session
+
+import (
+ "context"
+ "encoding/base32"
+ "log/slog"
+ "net/http"
+ "strings"
+ "time"
+
+ "github.com/gorilla/securecookie"
+ "github.com/gorilla/sessions"
+ "github.com/uptrace/bun"
+)
+
+const (
+ sessionIDLen = 32
+ defaultTableName = "sessions"
+ defaultMaxAge = 60 * 60 * 24 * 30 // 30 days
+ defaultPath = "/"
+)
+
+// Options for bunstore.
+type Options struct {
+ TableName string
+ SkipCreateTable bool
+}
+
+// Store represent a bunstore.
+type Store struct {
+ db *bun.DB
+ opts Options
+ Codecs []securecookie.Codec
+ SessionOpts *sessions.Options
+}
+
+type Model struct {
+ bun.BaseModel `bun:"table:sessions,alias:s"`
+
+ ID string `bun:",pk,unique"`
+ Data string
+ CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
+ UpdatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
+ ExpiresAt time.Time
+}
+
+type KeyPairs []string
+
+func (k KeyPairs) ToKeys() [][]byte {
+ b := make([][]byte, 0, len(k))
+ for _, kk := range k {
+ b = append(b, []byte(kk))
+ }
+
+ return b
+}
+
+// New creates a new bunstore session.
+func New(db *bun.DB, keyPairs KeyPairs) (*Store, error) {
+ return NewOptions(db, Options{}, keyPairs)
+}
+
+// NewOptions creates a new bunstore session with options.
+func NewOptions(db *bun.DB, opts Options, keyPairs KeyPairs) (*Store, error) {
+ st := &Store{
+ db: db,
+ opts: opts,
+ Codecs: securecookie.CodecsFromPairs(keyPairs.ToKeys()...),
+ SessionOpts: &sessions.Options{
+ Path: defaultPath,
+ MaxAge: defaultMaxAge,
+ },
+ }
+
+ return st, nil
+}
+
+// Get returns a session for the given name after adding it to the registry.
+func (st *Store) Get(r *http.Request, name string) (*sessions.Session, error) {
+ return sessions.GetRegistry(r).Get(st, name)
+}
+
+// New creates a session with name without adding it to the registry.
+func (st *Store) New(r *http.Request, name string) (*sessions.Session, error) {
+ session := sessions.NewSession(st, name)
+ opts := *st.SessionOpts
+ session.Options = &opts
+ session.IsNew = true
+
+ st.MaxAge(st.SessionOpts.MaxAge)
+
+ // try fetch from db if there is a cookie
+ s := st.getSessionFromCookie(r, session.Name())
+ if s != nil {
+ if err := securecookie.DecodeMulti(session.Name(), s.Data, &session.Values, st.Codecs...); err != nil {
+ //nolint:nilerr
+ return session, nil
+ }
+
+ session.ID = s.ID
+ session.IsNew = false
+ }
+
+ return session, nil
+}
+
+// Save session and set cookie header.
+func (st *Store) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error {
+ s := st.getSessionFromCookie(r, session.Name())
+
+ // delete if max age is < 0
+ if session.Options.MaxAge < 0 {
+ if s != nil {
+ if _, err := st.db.NewDelete().Model(&Model{ID: session.ID}).WherePK("id").Exec(r.Context()); err != nil {
+ return err
+ }
+ }
+
+ http.SetCookie(w, sessions.NewCookie(session.Name(), "", session.Options))
+
+ return nil
+ }
+
+ data, err := securecookie.EncodeMulti(session.Name(), session.Values, st.Codecs...)
+ if err != nil {
+ return err
+ }
+
+ now := time.Now()
+ expire := now.Add(time.Second * time.Duration(session.Options.MaxAge))
+
+ if s == nil {
+ // generate random session ID key suitable for storage in the db
+ session.ID = strings.TrimRight(
+ base32.StdEncoding.EncodeToString(
+ securecookie.GenerateRandomKey(sessionIDLen)), "=")
+ s = &Model{
+ ID: session.ID,
+ Data: data,
+ ExpiresAt: expire,
+ }
+
+ if _, err := st.db.NewInsert().Model(s).Exec(r.Context()); err != nil {
+ return err
+ }
+ } else {
+ s.Data = data
+ s.ExpiresAt = expire
+
+ if _, err := st.db.NewUpdate().Model(s).WherePK("id").Column("data", "expires_at").Exec(r.Context()); err != nil {
+ return err
+ }
+ }
+
+ // set session id cookie
+ id, err := securecookie.EncodeMulti(session.Name(), s.ID, st.Codecs...)
+ if err != nil {
+ return err
+ }
+
+ http.SetCookie(w, sessions.NewCookie(session.Name(), id, session.Options))
+
+ return nil
+}
+
+// getSessionFromCookie looks for an existing bunSession from a session ID stored inside a cookie.
+func (st *Store) getSessionFromCookie(r *http.Request, name string) *Model {
+ if cookie, err := r.Cookie(name); err == nil {
+ sessionID := ""
+ if err := securecookie.DecodeMulti(name, cookie.Value, &sessionID, st.Codecs...); err != nil {
+ return nil
+ }
+
+ s := &Model{}
+ if err := st.db.NewSelect().
+ Model(s).
+ Where("id = ? AND expires_at > ?", sessionID, time.Now()).
+ Scan(r.Context()); err != nil {
+ return nil
+ }
+
+ return s
+ }
+
+ return nil
+}
+
+// MaxAge sets the maximum age for the store and the underlying cookie
+// implementation. Individual sessions can be deleted by setting
+// Options.MaxAge = -1 for that session.
+func (st *Store) MaxAge(age int) {
+ st.SessionOpts.MaxAge = age
+ for _, codec := range st.Codecs {
+ if sc, ok := codec.(*securecookie.SecureCookie); ok {
+ sc.MaxAge(age)
+ }
+ }
+}
+
+// MaxLength restricts the maximum length of new sessions to l.
+// If l is 0 there is no limit to the size of a session, use with caution.
+// The default is 4096 (default for securecookie).
+func (st *Store) MaxLength(l int) {
+ for _, c := range st.Codecs {
+ if codec, ok := c.(*securecookie.SecureCookie); ok {
+ codec.MaxLength(l)
+ }
+ }
+}
+
+// Cleanup deletes expired sessions.
+func (st *Store) Cleanup() {
+ _, err := st.db.NewDelete().Model(&Model{}).Where("expires_at <= ?", time.Now()).Exec(context.Background())
+ if err != nil {
+ slog.Default().With("error", err).Error("cleanup")
+ }
+}
+
+// PeriodicCleanup runs Cleanup every interval. Close quit channel to stop.
+func (st *Store) PeriodicCleanup(interval time.Duration, quit <-chan struct{}) {
+ t := time.NewTicker(interval)
+ defer t.Stop()
+
+ for {
+ select {
+ case <-t.C:
+ st.Cleanup()
+ case <-quit:
+ return
+ }
+ }
+}
diff --git a/pkg/middleware/state.go b/pkg/middleware/state.go
new file mode 100644
index 0000000..c918411
--- /dev/null
+++ b/pkg/middleware/state.go
@@ -0,0 +1,58 @@
+package middleware
+
+import (
+ "context"
+ "encoding/gob"
+
+ "github.com/gorilla/sessions"
+ "github.com/labstack/echo-contrib/session"
+ "github.com/labstack/echo/v4"
+)
+
+func init() {
+ gob.Register(&State{})
+}
+
+func PopulateState() echo.MiddlewareFunc {
+ return func(next echo.HandlerFunc) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ sess, err := session.Get("state", c)
+ if err != nil {
+ return err
+ }
+
+ u := sess.Values["state"]
+ c.Set("state", u)
+
+ ctx := context.WithValue(c.Request().Context(), ContextKey("user"), u)
+ r := c.Request().WithContext(ctx)
+ c.SetRequest(r)
+
+ return next(c)
+ }
+ }
+}
+
+func SetState(c echo.Context, u State, maxage int) error {
+ sess, err := session.Get("state", c)
+ if err != nil {
+ return err
+ }
+
+ sess.Values["state"] = u
+ sess.Options = &sessions.Options{
+ Path: "/",
+ MaxAge: maxage,
+ HttpOnly: true,
+ Secure: true,
+ }
+
+ return sess.Save(c.Request(), c.Response())
+}
+
+type State struct {
+ Username string `json:"username"`
+ CurrentGUID string `json:"current_guid"`
+ Points int `json:"points"`
+ Image string `json:"image"`
+}