package session
import (
"context"
"encoding/base32"
"log/slog"
"net/http"
"strings"
"time"
"github.com/gorilla/securecookie"
"github.com/gorilla/sessions"
"github.com/uptrace/bun"
)
// Package-level defaults used when building a Store and its sessions.
const (
	// sessionIDLen is the number of random bytes behind each session ID.
	sessionIDLen = 32

	// defaultTableName is the default sessions table name.
	defaultTableName = "sessions"

	// defaultMaxAge is the default session lifetime in seconds (30 days).
	defaultMaxAge = 30 * 24 * 60 * 60

	// defaultPath is the default cookie path.
	defaultPath = "/"
)
// Options configures a Store built with NewOptions.
//
// NOTE(review): no query in this file reads TableName or SkipCreateTable (the
// Model struct hard-codes the "sessions" table). Confirm whether they are
// honored anywhere before relying on them.
type Options struct {
	// TableName is the intended name of the sessions table.
	TableName string
	// SkipCreateTable presumably disables automatic table creation; verify.
	SkipCreateTable bool
}
// Store is a gorilla/sessions store that keeps session values in a database
// table accessed through bun. The client cookie carries only the
// codec-encoded session ID.
type Store struct {
	// db is the bun handle used by every session query.
	db *bun.DB
	// opts holds the Options the store was constructed with.
	opts Options
	// Codecs encode and decode both the cookie value (the session ID) and
	// the session data stored in the database.
	Codecs []securecookie.Codec
	// SessionOpts is the template copied into each new session's Options.
	SessionOpts *sessions.Options
}
// Model is the database row that backs a single session.
type Model struct {
	bun.BaseModel `bun:"table:sessions,alias:s"`

	// ID is the random session identifier; it is also what the cookie carries.
	ID string `bun:",pk,unique"`
	// Data holds session.Values encoded with the store's codecs.
	Data string
	// CreatedAt and UpdatedAt fall back to the database's current_timestamp
	// when left at the zero value.
	CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
	UpdatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
	// ExpiresAt is the instant after which cookie lookups ignore the row.
	ExpiresAt time.Time
}
// KeyPairs is an ordered list of secret keys. ToKeys turns it into the byte
// slices passed to securecookie.CodecsFromPairs.
type KeyPairs []string

// ToKeys converts each key to a byte slice, preserving order. The result is
// never nil, even when k is empty.
func (k KeyPairs) ToKeys() [][]byte {
	keys := make([][]byte, len(k))
	for i, s := range k {
		keys[i] = []byte(s)
	}
	return keys
}
// New builds a Store on db using zero-valued Options; it is shorthand for
// NewOptions(db, Options{}, keyPairs).
func New(db *bun.DB, keyPairs KeyPairs) (*Store, error) {
	var defaults Options
	return NewOptions(db, defaults, keyPairs)
}
// NewOptions builds a Store on db with the given Options.
//
// keyPairs are turned into securecookie codecs. New sessions start from a
// template with Path defaultPath and MaxAge defaultMaxAge. The returned
// error is currently always nil.
func NewOptions(db *bun.DB, opts Options, keyPairs KeyPairs) (*Store, error) {
	keys := keyPairs.ToKeys()
	template := &sessions.Options{
		Path:   defaultPath,
		MaxAge: defaultMaxAge,
	}
	return &Store{
		db:          db,
		opts:        opts,
		Codecs:      securecookie.CodecsFromPairs(keys...),
		SessionOpts: template,
	}, nil
}
// Get returns the session called name through the request's session
// registry. The registry caches it, so repeated calls within one request
// share the same session.
func (st *Store) Get(r *http.Request, name string) (*sessions.Session, error) {
	registry := sessions.GetRegistry(r)
	return registry.Get(st, name)
}
// New builds the session called name without registering it.
//
// The session starts as new, with a copy of the store's SessionOpts. If the
// request carries a valid cookie pointing at an unexpired row, the row's
// data is decoded into session.Values and the session takes the row's ID.
// A row whose data cannot be decoded is ignored: the session stays new and
// no error is returned.
func (st *Store) New(r *http.Request, name string) (*sessions.Session, error) {
	sess := sessions.NewSession(st, name)
	optsCopy := *st.SessionOpts
	sess.Options = &optsCopy
	sess.IsNew = true

	// NOTE(review): this writes store-wide state (SessionOpts.MaxAge and each
	// codec's max age) on every call, so concurrent requests race on it. It
	// appears to resync the codecs after SessionOpts.MaxAge is changed
	// directly. It must run before the cookie is decoded below.
	st.MaxAge(st.SessionOpts.MaxAge)

	row := st.getSessionFromCookie(r, sess.Name())
	if row == nil {
		return sess, nil
	}
	if err := securecookie.DecodeMulti(sess.Name(), row.Data, &sess.Values, st.Codecs...); err != nil {
		// Deliberately best-effort: undecodable data yields a fresh session.
		return sess, nil //nolint:nilerr
	}
	sess.ID = row.ID
	sess.IsNew = false
	return sess, nil
}
// Save session and set cookie header.
func (st *Store) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error {
s := st.getSessionFromCookie(r, session.Name())
// delete if max age is < 0
if session.Options.MaxAge < 0 {
if s != nil {
if _, err := st.db.NewDelete().Model(&Model{ID: session.ID}).WherePK("id").Exec(r.Context()); err != nil {
return err
}
}
http.SetCookie(w, sessions.NewCookie(session.Name(), "", session.Options))
return nil
}
data, err := securecookie.EncodeMulti(session.Name(), session.Values, st.Codecs...)
if err != nil {
return err
}
now := time.Now()
expire := now.Add(time.Second * time.Duration(session.Options.MaxAge))
if s == nil {
// generate random session ID key suitable for storage in the db
session.ID = strings.TrimRight(
base32.StdEncoding.EncodeToString(
securecookie.GenerateRandomKey(sessionIDLen)), "=")
s = &Model{
ID: session.ID,
Data: data,
ExpiresAt: expire,
}
if _, err := st.db.NewInsert().Model(s).Exec(r.Context()); err != nil {
return err
}
} else {
s.Data = data
s.ExpiresAt = expire
if _, err := st.db.NewUpdate().Model(s).WherePK("id").Column("data", "expires_at").Exec(r.Context()); err != nil {
return err
}
}
// set session id cookie
id, err := securecookie.EncodeMulti(session.Name(), s.ID, st.Codecs...)
if err != nil {
return err
}
http.SetCookie(w, sessions.NewCookie(session.Name(), id, session.Options))
return nil
}
// getSessionFromCookie looks for an existing bunSession from a session ID stored inside a cookie.
func (st *Store) getSessionFromCookie(r *http.Request, name string) *Model {
if cookie, err := r.Cookie(name); err == nil {
sessionID := ""
if err := securecookie.DecodeMulti(name, cookie.Value, &sessionID, st.Codecs...); err != nil {
return nil
}
s := &Model{}
if err := st.db.NewSelect().
Model(s).
Where("id = ? AND expires_at > ?", sessionID, time.Now()).
Scan(r.Context()); err != nil {
return nil
}
return s
}
return nil
}
// MaxAge sets the store-wide session lifetime, in seconds. It updates
// SessionOpts.MaxAge and the max age of every *securecookie.SecureCookie
// codec; other codec types are left alone. A single session can be deleted
// by setting its Options.MaxAge to -1 before saving.
func (st *Store) MaxAge(age int) {
	st.SessionOpts.MaxAge = age
	for i := range st.Codecs {
		sc, ok := st.Codecs[i].(*securecookie.SecureCookie)
		if !ok {
			continue
		}
		sc.MaxAge(age)
	}
}
// MaxLength caps the encoded length of values produced by every
// *securecookie.SecureCookie codec at l; other codec types are left alone.
// An l of 0 removes the limit, so use it with caution. securecookie's own
// default is 4096.
func (st *Store) MaxLength(l int) {
	for i := range st.Codecs {
		sc, ok := st.Codecs[i].(*securecookie.SecureCookie)
		if !ok {
			continue
		}
		sc.MaxLength(l)
	}
}
// Cleanup deletes every session row whose expires_at is at or before now.
// Failures are logged through the default slog logger rather than returned.
// The query runs with context.Background, so it cannot be cancelled.
func (st *Store) Cleanup() {
	ctx := context.Background()
	cutoff := time.Now()
	query := st.db.NewDelete().Model(&Model{}).Where("expires_at <= ?", cutoff)
	if _, err := query.Exec(ctx); err != nil {
		logger := slog.Default().With("error", err)
		logger.Error("cleanup")
	}
}
// PeriodicCleanup runs Cleanup every interval. Close quit channel to stop.
func (st *Store) PeriodicCleanup(interval time.Duration, quit <-chan struct{}) {
t := time.NewTicker(interval)
defer t.Stop()
for {
select {
case <-t.C:
st.Cleanup()
case <-quit:
return
}
}
}