diff options
Diffstat (limited to 'pkg/middleware/session/store.go')
-rw-r--r-- | pkg/middleware/session/store.go | 232 |
1 files changed, 232 insertions, 0 deletions
diff --git a/pkg/middleware/session/store.go b/pkg/middleware/session/store.go new file mode 100644 index 0000000..78172d7 --- /dev/null +++ b/pkg/middleware/session/store.go @@ -0,0 +1,232 @@ +package session + +import ( + "context" + "encoding/base32" + "log/slog" + "net/http" + "strings" + "time" + + "github.com/gorilla/securecookie" + "github.com/gorilla/sessions" + "github.com/uptrace/bun" +) + +const ( + sessionIDLen = 32 + defaultTableName = "sessions" + defaultMaxAge = 60 * 60 * 24 * 30 // 30 days + defaultPath = "/" +) + +// Options for bunstore. +type Options struct { + TableName string + SkipCreateTable bool +} + +// Store represent a bunstore. +type Store struct { + db *bun.DB + opts Options + Codecs []securecookie.Codec + SessionOpts *sessions.Options +} + +type Model struct { + bun.BaseModel `bun:"table:sessions,alias:s"` + + ID string `bun:",pk,unique"` + Data string + CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"` + UpdatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"` + ExpiresAt time.Time +} + +type KeyPairs []string + +func (k KeyPairs) ToKeys() [][]byte { + b := make([][]byte, 0, len(k)) + for _, kk := range k { + b = append(b, []byte(kk)) + } + + return b +} + +// New creates a new bunstore session. +func New(db *bun.DB, keyPairs KeyPairs) (*Store, error) { + return NewOptions(db, Options{}, keyPairs) +} + +// NewOptions creates a new bunstore session with options. +func NewOptions(db *bun.DB, opts Options, keyPairs KeyPairs) (*Store, error) { + st := &Store{ + db: db, + opts: opts, + Codecs: securecookie.CodecsFromPairs(keyPairs.ToKeys()...), + SessionOpts: &sessions.Options{ + Path: defaultPath, + MaxAge: defaultMaxAge, + }, + } + + return st, nil +} + +// Get returns a session for the given name after adding it to the registry. +func (st *Store) Get(r *http.Request, name string) (*sessions.Session, error) { + return sessions.GetRegistry(r).Get(st, name) +} + +// New creates a session with name without adding it to the registry. +func (st *Store) New(r *http.Request, name string) (*sessions.Session, error) { + session := sessions.NewSession(st, name) + opts := *st.SessionOpts + session.Options = &opts + session.IsNew = true + + st.MaxAge(st.SessionOpts.MaxAge) + + // try fetch from db if there is a cookie + s := st.getSessionFromCookie(r, session.Name()) + if s != nil { + if err := securecookie.DecodeMulti(session.Name(), s.Data, &session.Values, st.Codecs...); err != nil { + //nolint:nilerr + return session, nil + } + + session.ID = s.ID + session.IsNew = false + } + + return session, nil +} + +// Save session and set cookie header. +func (st *Store) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error { + s := st.getSessionFromCookie(r, session.Name()) + + // delete if max age is < 0 + if session.Options.MaxAge < 0 { + if s != nil { + if _, err := st.db.NewDelete().Model(&Model{ID: session.ID}).WherePK("id").Exec(r.Context()); err != nil { + return err + } + } + + http.SetCookie(w, sessions.NewCookie(session.Name(), "", session.Options)) + + return nil + } + + data, err := securecookie.EncodeMulti(session.Name(), session.Values, st.Codecs...) + if err != nil { + return err + } + + now := time.Now() + expire := now.Add(time.Second * time.Duration(session.Options.MaxAge)) + + if s == nil { + // generate random session ID key suitable for storage in the db + session.ID = strings.TrimRight( + base32.StdEncoding.EncodeToString( + securecookie.GenerateRandomKey(sessionIDLen)), "=") + s = &Model{ + ID: session.ID, + Data: data, + ExpiresAt: expire, + } + + if _, err := st.db.NewInsert().Model(s).Exec(r.Context()); err != nil { + return err + } + } else { + s.Data = data + s.ExpiresAt = expire + + if _, err := st.db.NewUpdate().Model(s).WherePK("id").Column("data", "expires_at").Exec(r.Context()); err != nil { + return err + } + } + + // set session id cookie + id, err := securecookie.EncodeMulti(session.Name(), s.ID, st.Codecs...) + if err != nil { + return err + } + + http.SetCookie(w, sessions.NewCookie(session.Name(), id, session.Options)) + + return nil +} + +// getSessionFromCookie looks for an existing bunSession from a session ID stored inside a cookie. +func (st *Store) getSessionFromCookie(r *http.Request, name string) *Model { + if cookie, err := r.Cookie(name); err == nil { + sessionID := "" + if err := securecookie.DecodeMulti(name, cookie.Value, &sessionID, st.Codecs...); err != nil { + return nil + } + + s := &Model{} + if err := st.db.NewSelect(). + Model(s). + Where("id = ? AND expires_at > ?", sessionID, time.Now()). + Scan(r.Context()); err != nil { + return nil + } + + return s + } + + return nil +} + +// MaxAge sets the maximum age for the store and the underlying cookie +// implementation. Individual sessions can be deleted by setting +// Options.MaxAge = -1 for that session. +func (st *Store) MaxAge(age int) { + st.SessionOpts.MaxAge = age + for _, codec := range st.Codecs { + if sc, ok := codec.(*securecookie.SecureCookie); ok { + sc.MaxAge(age) + } + } +} + +// MaxLength restricts the maximum length of new sessions to l. +// If l is 0 there is no limit to the size of a session, use with caution. +// The default is 4096 (default for securecookie). +func (st *Store) MaxLength(l int) { + for _, c := range st.Codecs { + if codec, ok := c.(*securecookie.SecureCookie); ok { + codec.MaxLength(l) + } + } +} + +// Cleanup deletes expired sessions. +func (st *Store) Cleanup() { + _, err := st.db.NewDelete().Model(&Model{}).Where("expires_at <= ?", time.Now()).Exec(context.Background()) + if err != nil { + slog.Default().With("error", err).Error("cleanup") + } +} + +// PeriodicCleanup runs Cleanup every interval. Close quit channel to stop. +func (st *Store) PeriodicCleanup(interval time.Duration, quit <-chan struct{}) { + t := time.NewTicker(interval) + defer t.Stop() + + for { + select { + case <-t.C: + st.Cleanup() + case <-quit: + return + } + } +} |