aboutsummaryrefslogtreecommitdiff
path: root/pkg/middleware/session
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/middleware/session')
-rw-r--r--pkg/middleware/session/store.go230
1 files changed, 230 insertions, 0 deletions
diff --git a/pkg/middleware/session/store.go b/pkg/middleware/session/store.go
new file mode 100644
index 0000000..04071c9
--- /dev/null
+++ b/pkg/middleware/session/store.go
@@ -0,0 +1,230 @@
+package session
+
+import (
+ "context"
+ "encoding/base32"
+ "log/slog"
+ "net/http"
+ "strings"
+ "time"
+
+ "github.com/gorilla/securecookie"
+ "github.com/gorilla/sessions"
+ "github.com/uptrace/bun"
+)
+
+const sessionIDLen = 32
+const defaultTableName = "sessions"
+const defaultMaxAge = 60 * 60 * 24 * 30 // 30 days
+const defaultPath = "/"
+
+// Options for bunstore
+type Options struct {
+ TableName string
+ SkipCreateTable bool
+}
+
+// Store represent a bunstore
+type Store struct {
+ db *bun.DB
+ opts Options
+ Codecs []securecookie.Codec
+ SessionOpts *sessions.Options
+}
+
+type bunSession struct {
+ bun.BaseModel `bun:"table:sessions,alias:s"`
+
+ ID string `bun:",pk,unique"`
+ Data string
+ CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
+ UpdatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
+ ExpiresAt time.Time
+}
+
+type KeyPairs []string
+
+func (k KeyPairs) ToKeys() [][]byte {
+ b := make([][]byte, 0, len(k))
+ for _, kk := range k {
+ b = append(b, []byte(kk))
+ }
+ return b
+}
+
+// New creates a new bunstore session
+func New(db *bun.DB, keyPairs KeyPairs) (*Store, error) {
+ return NewOptions(db, Options{}, keyPairs)
+}
+
+// NewOptions creates a new bunstore session with options
+func NewOptions(db *bun.DB, opts Options, keyPairs KeyPairs) (*Store, error) {
+ st := &Store{
+ db: db,
+ opts: opts,
+ Codecs: securecookie.CodecsFromPairs(keyPairs.ToKeys()...),
+ SessionOpts: &sessions.Options{
+ Path: defaultPath,
+ MaxAge: defaultMaxAge,
+ },
+ }
+ if st.opts.TableName == "" {
+ st.opts.TableName = defaultTableName
+ }
+
+ if !st.opts.SkipCreateTable {
+ model := &bunSession{}
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+ if _, err := db.NewCreateTable().IfNotExists().Model(model).Exec(ctx); err != nil {
+ return nil, err
+ }
+ if _, err := db.NewCreateIndex().Model(model).Column("expires_at").Exec(ctx); err != nil {
+ return nil, err
+ }
+ }
+
+ return st, nil
+}
+
+// Get returns a session for the given name after adding it to the registry.
+func (st *Store) Get(r *http.Request, name string) (*sessions.Session, error) {
+ return sessions.GetRegistry(r).Get(st, name)
+}
+
+// New creates a session with name without adding it to the registry.
+func (st *Store) New(r *http.Request, name string) (*sessions.Session, error) {
+ session := sessions.NewSession(st, name)
+ opts := *st.SessionOpts
+ session.Options = &opts
+ session.IsNew = true
+
+ st.MaxAge(st.SessionOpts.MaxAge)
+
+ // try fetch from db if there is a cookie
+ s := st.getSessionFromCookie(r, session.Name())
+ if s != nil {
+ if err := securecookie.DecodeMulti(session.Name(), s.Data, &session.Values, st.Codecs...); err != nil {
+ return session, nil
+ }
+ session.ID = s.ID
+ session.IsNew = false
+ }
+
+ return session, nil
+}
+
+// Save session and set cookie header
+func (st *Store) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error {
+ s := st.getSessionFromCookie(r, session.Name())
+
+ // delete if max age is < 0
+ if session.Options.MaxAge < 0 {
+ if s != nil {
+ if _, err := st.db.NewDelete().Model(&bunSession{ID: session.ID}).Exec(r.Context()); err != nil {
+ return err
+ }
+ }
+ http.SetCookie(w, sessions.NewCookie(session.Name(), "", session.Options))
+ return nil
+ }
+
+ data, err := securecookie.EncodeMulti(session.Name(), session.Values, st.Codecs...)
+ if err != nil {
+ return err
+ }
+ now := time.Now()
+ expire := now.Add(time.Second * time.Duration(session.Options.MaxAge))
+
+ if s == nil {
+ // generate random session ID key suitable for storage in the db
+ session.ID = strings.TrimRight(
+ base32.StdEncoding.EncodeToString(
+ securecookie.GenerateRandomKey(sessionIDLen)), "=")
+ s = &bunSession{
+ ID: session.ID,
+ Data: data,
+ ExpiresAt: expire,
+ }
+ if _, err := st.db.NewInsert().Model(s).Exec(r.Context()); err != nil {
+ return err
+ }
+ } else {
+ s.Data = data
+ s.ExpiresAt = expire
+ if _, err := st.db.NewUpdate().Model(s).WherePK("id").Column("data", "expires_at").Exec(r.Context()); err != nil {
+ return err
+ }
+ }
+
+ // set session id cookie
+ id, err := securecookie.EncodeMulti(session.Name(), s.ID, st.Codecs...)
+ if err != nil {
+ return err
+ }
+ http.SetCookie(w, sessions.NewCookie(session.Name(), id, session.Options))
+
+ return nil
+}
+
+// getSessionFromCookie looks for an existing bunSession from a session ID stored inside a cookie
+func (st *Store) getSessionFromCookie(r *http.Request, name string) *bunSession {
+ if cookie, err := r.Cookie(name); err == nil {
+ sessionID := ""
+ if err := securecookie.DecodeMulti(name, cookie.Value, &sessionID, st.Codecs...); err != nil {
+ return nil
+ }
+ s := &bunSession{}
+ err := st.db.NewSelect().Model(s).Where("id = ? AND expires_at > ?", sessionID, time.Now()).Scan(r.Context())
+ if err != nil {
+ return nil
+ }
+ return s
+ }
+ return nil
+}
+
+// MaxAge sets the maximum age for the store and the underlying cookie
+// implementation. Individual sessions can be deleted by setting
+// Options.MaxAge = -1 for that session.
+func (st *Store) MaxAge(age int) {
+ st.SessionOpts.MaxAge = age
+ for _, codec := range st.Codecs {
+ if sc, ok := codec.(*securecookie.SecureCookie); ok {
+ sc.MaxAge(age)
+ }
+ }
+}
+
+// MaxLength restricts the maximum length of new sessions to l.
+// If l is 0 there is no limit to the size of a session, use with caution.
+// The default is 4096 (default for securecookie)
+func (st *Store) MaxLength(l int) {
+ for _, c := range st.Codecs {
+ if codec, ok := c.(*securecookie.SecureCookie); ok {
+ codec.MaxLength(l)
+ }
+ }
+}
+
+// Cleanup deletes expired sessions
+func (st *Store) Cleanup() {
+ _, err := st.db.NewDelete().Model(&bunSession{}).Where("expires_at <= ?", time.Now()).Exec(context.Background())
+ if err != nil {
+ slog.Default().With("error", err).Error("cleanup")
+ }
+}
+
+// PeriodicCleanup runs Cleanup every interval. Close quit channel to stop.
+func (st *Store) PeriodicCleanup(interval time.Duration, quit <-chan struct{}) {
+ t := time.NewTicker(interval)
+ defer t.Stop()
+ for {
+ select {
+ case <-t.C:
+ st.Cleanup()
+ case <-quit:
+ return
+ }
+ }
+}