diff options
author | Alexander Neonxp Kiryukhin <i@neonxp.ru> | 2024-10-08 03:43:08 +0300 |
---|---|---|
committer | Alexander Neonxp Kiryukhin <i@neonxp.ru> | 2024-10-08 03:50:53 +0300 |
commit | e849e705c30cceec3cf7336a21bed96c8a911e90 (patch) | |
tree | 93f559bcd4cf3e53193930d112e564a2b7462ac8 /pkg/middleware/session/store.go | |
parent | 3ee654f6fb3cdf119630bfba8066c96ec26428c3 (diff) |
Добавил рейтинг
Добавил страницу топа
Добавил rss/xml/json feed
Diffstat (limited to 'pkg/middleware/session/store.go')
-rw-r--r-- | pkg/middleware/session/store.go | 230 |
1 files changed, 230 insertions, 0 deletions
diff --git a/pkg/middleware/session/store.go b/pkg/middleware/session/store.go new file mode 100644 index 0000000..04071c9 --- /dev/null +++ b/pkg/middleware/session/store.go @@ -0,0 +1,230 @@ +package session + +import ( + "context" + "encoding/base32" + "log/slog" + "net/http" + "strings" + "time" + + "github.com/gorilla/securecookie" + "github.com/gorilla/sessions" + "github.com/uptrace/bun" +) + +const sessionIDLen = 32 +const defaultTableName = "sessions" +const defaultMaxAge = 60 * 60 * 24 * 30 // 30 days +const defaultPath = "/" + +// Options for bunstore +type Options struct { + TableName string + SkipCreateTable bool +} + +// Store represent a bunstore +type Store struct { + db *bun.DB + opts Options + Codecs []securecookie.Codec + SessionOpts *sessions.Options +} + +type bunSession struct { + bun.BaseModel `bun:"table:sessions,alias:s"` + + ID string `bun:",pk,unique"` + Data string + CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"` + UpdatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"` + ExpiresAt time.Time +} + +type KeyPairs []string + +func (k KeyPairs) ToKeys() [][]byte { + b := make([][]byte, 0, len(k)) + for _, kk := range k { + b = append(b, []byte(kk)) + } + return b +} + +// New creates a new bunstore session +func New(db *bun.DB, keyPairs KeyPairs) (*Store, error) { + return NewOptions(db, Options{}, keyPairs) +} + +// NewOptions creates a new bunstore session with options +func NewOptions(db *bun.DB, opts Options, keyPairs KeyPairs) (*Store, error) { + st := &Store{ + db: db, + opts: opts, + Codecs: securecookie.CodecsFromPairs(keyPairs.ToKeys()...), + SessionOpts: &sessions.Options{ + Path: defaultPath, + MaxAge: defaultMaxAge, + }, + } + if st.opts.TableName == "" { + st.opts.TableName = defaultTableName + } + + if !st.opts.SkipCreateTable { + model := &bunSession{} + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + if _, err := db.NewCreateTable().IfNotExists().Model(model).Exec(ctx); err != nil { + return nil, err + } + if _, err := db.NewCreateIndex().Model(model).Column("expires_at").Exec(ctx); err != nil { + return nil, err + } + } + + return st, nil +} + +// Get returns a session for the given name after adding it to the registry. +func (st *Store) Get(r *http.Request, name string) (*sessions.Session, error) { + return sessions.GetRegistry(r).Get(st, name) +} + +// New creates a session with name without adding it to the registry. +func (st *Store) New(r *http.Request, name string) (*sessions.Session, error) { + session := sessions.NewSession(st, name) + opts := *st.SessionOpts + session.Options = &opts + session.IsNew = true + + st.MaxAge(st.SessionOpts.MaxAge) + + // try fetch from db if there is a cookie + s := st.getSessionFromCookie(r, session.Name()) + if s != nil { + if err := securecookie.DecodeMulti(session.Name(), s.Data, &session.Values, st.Codecs...); err != nil { + return session, nil + } + session.ID = s.ID + session.IsNew = false + } + + return session, nil +} + +// Save session and set cookie header +func (st *Store) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error { + s := st.getSessionFromCookie(r, session.Name()) + + // delete if max age is < 0 + if session.Options.MaxAge < 0 { + if s != nil { + if _, err := st.db.NewDelete().Model(&bunSession{ID: session.ID}).Exec(r.Context()); err != nil { + return err + } + } + http.SetCookie(w, sessions.NewCookie(session.Name(), "", session.Options)) + return nil + } + + data, err := securecookie.EncodeMulti(session.Name(), session.Values, st.Codecs...) + if err != nil { + return err + } + now := time.Now() + expire := now.Add(time.Second * time.Duration(session.Options.MaxAge)) + + if s == nil { + // generate random session ID key suitable for storage in the db + session.ID = strings.TrimRight( + base32.StdEncoding.EncodeToString( + securecookie.GenerateRandomKey(sessionIDLen)), "=") + s = &bunSession{ + ID: session.ID, + Data: data, + ExpiresAt: expire, + } + if _, err := st.db.NewInsert().Model(s).Exec(r.Context()); err != nil { + return err + } + } else { + s.Data = data + s.ExpiresAt = expire + if _, err := st.db.NewUpdate().Model(s).WherePK("id").Column("data", "expires_at").Exec(r.Context()); err != nil { + return err + } + } + + // set session id cookie + id, err := securecookie.EncodeMulti(session.Name(), s.ID, st.Codecs...) + if err != nil { + return err + } + http.SetCookie(w, sessions.NewCookie(session.Name(), id, session.Options)) + + return nil +} + +// getSessionFromCookie looks for an existing bunSession from a session ID stored inside a cookie +func (st *Store) getSessionFromCookie(r *http.Request, name string) *bunSession { + if cookie, err := r.Cookie(name); err == nil { + sessionID := "" + if err := securecookie.DecodeMulti(name, cookie.Value, &sessionID, st.Codecs...); err != nil { + return nil + } + s := &bunSession{} + err := st.db.NewSelect().Model(s).Where("id = ? AND expires_at > ?", sessionID, time.Now()).Scan(r.Context()) + if err != nil { + return nil + } + return s + } + return nil +} + +// MaxAge sets the maximum age for the store and the underlying cookie +// implementation. Individual sessions can be deleted by setting +// Options.MaxAge = -1 for that session. +func (st *Store) MaxAge(age int) { + st.SessionOpts.MaxAge = age + for _, codec := range st.Codecs { + if sc, ok := codec.(*securecookie.SecureCookie); ok { + sc.MaxAge(age) + } + } +} + +// MaxLength restricts the maximum length of new sessions to l. +// If l is 0 there is no limit to the size of a session, use with caution. +// The default is 4096 (default for securecookie) +func (st *Store) MaxLength(l int) { + for _, c := range st.Codecs { + if codec, ok := c.(*securecookie.SecureCookie); ok { + codec.MaxLength(l) + } + } +} + +// Cleanup deletes expired sessions +func (st *Store) Cleanup() { + _, err := st.db.NewDelete().Model(&bunSession{}).Where("expires_at <= ?", time.Now()).Exec(context.Background()) + if err != nil { + slog.Default().With("error", err).Error("cleanup") + } +} + +// PeriodicCleanup runs Cleanup every interval. Close quit channel to stop. +func (st *Store) PeriodicCleanup(interval time.Duration, quit <-chan struct{}) { + t := time.NewTicker(interval) + defer t.Stop() + for { + select { + case <-t.C: + st.Cleanup() + case <-quit: + return + } + } +} |