aboutsummaryrefslogblamecommitdiff
path: root/pkg/middleware/session/store.go
blob: 04071c900fdf0b325e1d4ace37dc5d48cf8deba2 (plain) (tree)





































































































































































































































                                                                                                                                  
package session

import (
	"context"
	"encoding/base32"
	"log/slog"
	"net/http"
	"strings"
	"time"

	"github.com/gorilla/securecookie"
	"github.com/gorilla/sessions"
	"github.com/uptrace/bun"
)

// Package-level defaults used by the bun-backed session store.
const (
	sessionIDLen     = 32                // random bytes per generated session ID
	defaultTableName = "sessions"        // fallback when Options.TableName is empty
	defaultMaxAge    = 30 * 24 * 60 * 60 // 30 days, in seconds
	defaultPath      = "/"               // cookie path for new sessions
)

// Options configures a Store created with NewOptions.
type Options struct {
	// TableName is the sessions table name; NewOptions substitutes
	// defaultTableName ("sessions") when it is empty.
	// NOTE(review): the queries in this file take the table from the bun tag on
	// bunSession and do not appear to apply this value — confirm before relying on it.
	TableName       string
	// SkipCreateTable, when true, makes NewOptions skip the CREATE TABLE and
	// CREATE INDEX statements it otherwise runs.
	SkipCreateTable bool
}

// Store represents a bun-backed session store. It provides the Get, New and
// Save methods used by gorilla/sessions, keeps the encoded session values in a
// database row, and writes only the encoded session ID into the cookie.
type Store struct {
	// db is the bun handle used for every session query.
	db          *bun.DB
	// opts holds the Options passed to NewOptions (TableName defaulted).
	opts        Options
	// Codecs encode/decode both the session-ID cookie value and the stored
	// session values; built from the key pairs given to New/NewOptions.
	Codecs      []securecookie.Codec
	// SessionOpts is the template copied into each new session's Options.
	SessionOpts *sessions.Options
}

// bunSession is the database row stored for one session.
type bunSession struct {
	// The table name is fixed here by the bun tag ("sessions", alias "s").
	// NOTE(review): Options.TableName is not consulted by the visible queries.
	bun.BaseModel `bun:"table:sessions,alias:s"`

	// ID is the random session identifier; its encoded form is the cookie value.
	ID        string `bun:",pk,unique"`
	// Data holds session.Values encoded with the store's codecs.
	Data      string
	CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
	UpdatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
	// ExpiresAt is compared against the current time: rows at or past it are
	// ignored by cookie lookups and removed by Cleanup.
	ExpiresAt time.Time
}

// KeyPairs holds the string keys handed to securecookie.CodecsFromPairs,
// presumably as alternating hash (authentication) and block (encryption)
// keys, per the securecookie convention.
type KeyPairs []string

// ToKeys converts the key pairs to the [][]byte form expected by
// securecookie.CodecsFromPairs.
//
// An empty string is converted to a nil slice rather than an empty non-nil
// one. securecookie only disables encryption when the block key is nil; a
// non-nil empty block key is passed to aes.NewCipher, which rejects it and
// leaves the codec permanently returning an error. For hash keys, nil and
// empty behave the same (both report "hash key is not set").
func (k KeyPairs) ToKeys() [][]byte {
	keys := make([][]byte, 0, len(k))
	for _, s := range k {
		if s == "" {
			keys = append(keys, nil)
			continue
		}
		keys = append(keys, []byte(s))
	}
	return keys
}

// New creates a bun-backed Store using zero-value Options, so NewOptions
// applies its defaults (table name "sessions", schema creation enabled).
func New(db *bun.DB, keyPairs KeyPairs) (*Store, error) {
	var defaults Options
	return NewOptions(db, defaults, keyPairs)
}

// NewOptions creates a bun-backed Store configured by opts.
//
// keyPairs are turned into securecookie codecs; new sessions default to
// Path "/" and a 30-day MaxAge. An empty opts.TableName is replaced by
// "sessions". Unless opts.SkipCreateTable is set, the sessions table and an
// index on expires_at are created if they do not already exist, with a
// 30-second timeout for the schema statements. Any error from those
// statements is returned together with a nil Store.
func NewOptions(db *bun.DB, opts Options, keyPairs KeyPairs) (*Store, error) {
	st := &Store{
		db:     db,
		opts:   opts,
		Codecs: securecookie.CodecsFromPairs(keyPairs.ToKeys()...),
		SessionOpts: &sessions.Options{
			Path:   defaultPath,
			MaxAge: defaultMaxAge,
		},
	}
	if st.opts.TableName == "" {
		st.opts.TableName = defaultTableName
	}

	if !st.opts.SkipCreateTable {
		model := &bunSession{}
		ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
		defer cancel()
		if _, err := db.NewCreateTable().IfNotExists().Model(model).Exec(ctx); err != nil {
			return nil, err
		}
		// Name the index and use IF NOT EXISTS so repeated startups are
		// idempotent. An unnamed CREATE INDEX is a syntax error in some
		// dialects (e.g. SQLite) and, in PostgreSQL, silently adds another
		// auto-named duplicate index on every start.
		if _, err := db.NewCreateIndex().
			Model(model).
			Index("sessions_expires_at_idx").
			IfNotExists().
			Column("expires_at").
			Exec(ctx); err != nil {
			return nil, err
		}
	}

	return st, nil
}

// Get returns the session called name from the request's gorilla/sessions
// registry, creating it through this store (and registering it) when the
// registry does not already hold it.
func (st *Store) Get(r *http.Request, name string) (*sessions.Session, error) {
	registry := sessions.GetRegistry(r)
	return registry.Get(st, name)
}

// New creates a session called name without adding it to the registry.
//
// The session starts as new, with a private copy of st.SessionOpts. If the
// request carries a cookie that resolves to an unexpired stored row whose
// data decodes, the stored values and ID are loaded and IsNew is cleared.
// Undecodable data is not an error: the fresh session is returned instead.
// The returned error is always nil.
func (st *Store) New(r *http.Request, name string) (*sessions.Session, error) {
	sess := sessions.NewSession(st, name)
	optsCopy := *st.SessionOpts
	sess.Options = &optsCopy
	sess.IsNew = true

	// Push SessionOpts.MaxAge into the codecs before any cookie is decoded.
	// NOTE(review): this writes shared Store state on every call with no
	// visible locking; concurrent requests may race here — confirm.
	st.MaxAge(st.SessionOpts.MaxAge)

	stored := st.getSessionFromCookie(r, sess.Name())
	if stored == nil {
		return sess, nil
	}
	if err := securecookie.DecodeMulti(sess.Name(), stored.Data, &sess.Values, st.Codecs...); err != nil {
		return sess, nil
	}
	sess.ID = stored.ID
	sess.IsNew = false
	return sess, nil
}

// sessionIDError is returned by Save when securecookie.GenerateRandomKey
// cannot supply random bytes for a new session ID.
type sessionIDError struct{}

func (sessionIDError) Error() string { return "session: failed to generate random session ID" }

// newSessionID returns a random identifier suitable for storage in the db:
// sessionIDLen random bytes, base32-encoded with the '=' padding trimmed.
func newSessionID() (string, error) {
	key := securecookie.GenerateRandomKey(sessionIDLen)
	if key == nil {
		// GenerateRandomKey reports a random-source failure by returning nil;
		// without this check the session would be stored under the empty ID.
		return "", sessionIDError{}
	}
	return strings.TrimRight(base32.StdEncoding.EncodeToString(key), "="), nil
}

// Save persists session to the database and writes the session-ID cookie.
//
// When session.Options.MaxAge < 0, the row referenced by the request cookie
// (if any) is deleted and an empty cookie carrying the session options is
// written. Otherwise session.Values are encoded with the store codecs and
// either inserted under a freshly generated ID (no valid cookie/row) or used
// to update the existing row. Afterwards the row ID is encoded into the cookie.
// Database and encoding errors are returned unchanged.
func (st *Store) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error {
	s := st.getSessionFromCookie(r, session.Name())

	// delete if max age is < 0
	if session.Options.MaxAge < 0 {
		if s != nil {
			// Delete the row actually found via the cookie. session.ID may be
			// empty (New leaves it unset when the stored data fails to decode),
			// and bun refuses a DELETE without a WHERE clause, so WherePK is required.
			if _, err := st.db.NewDelete().Model(s).WherePK().Exec(r.Context()); err != nil {
				return err
			}
		}
		http.SetCookie(w, sessions.NewCookie(session.Name(), "", session.Options))
		return nil
	}

	data, err := securecookie.EncodeMulti(session.Name(), session.Values, st.Codecs...)
	if err != nil {
		return err
	}
	now := time.Now()
	expire := now.Add(time.Second * time.Duration(session.Options.MaxAge))

	if s == nil {
		id, err := newSessionID()
		if err != nil {
			return err
		}
		session.ID = id
		s = &bunSession{
			ID:        id,
			Data:      data,
			ExpiresAt: expire,
		}
		if _, err := st.db.NewInsert().Model(s).Exec(r.Context()); err != nil {
			return err
		}
	} else {
		// Keep the in-memory session consistent with the row being updated.
		session.ID = s.ID
		s.Data = data
		s.ExpiresAt = expire
		s.UpdatedAt = now
		if _, err := st.db.NewUpdate().
			Model(s).
			WherePK("id").
			Column("data", "expires_at", "updated_at").
			Exec(r.Context()); err != nil {
			return err
		}
	}

	// set session id cookie
	cookieValue, err := securecookie.EncodeMulti(session.Name(), s.ID, st.Codecs...)
	if err != nil {
		return err
	}
	http.SetCookie(w, sessions.NewCookie(session.Name(), cookieValue, session.Options))

	return nil
}

// getSessionFromCookie looks for an existing bunSession from a session ID stored inside a cookie
func (st *Store) getSessionFromCookie(r *http.Request, name string) *bunSession {
	if cookie, err := r.Cookie(name); err == nil {
		sessionID := ""
		if err := securecookie.DecodeMulti(name, cookie.Value, &sessionID, st.Codecs...); err != nil {
			return nil
		}
		s := &bunSession{}
		err := st.db.NewSelect().Model(s).Where("id = ? AND expires_at > ?", sessionID, time.Now()).Scan(r.Context())
		if err != nil {
			return nil
		}
		return s
	}
	return nil
}

// MaxAge sets the default session MaxAge (seconds) in st.SessionOpts and
// applies the same limit to every *securecookie.SecureCookie codec; other
// codec implementations are left alone. Individual sessions can be deleted by
// setting Options.MaxAge = -1 for that session before saving it.
func (st *Store) MaxAge(age int) {
	st.SessionOpts.MaxAge = age
	for i := range st.Codecs {
		sc, ok := st.Codecs[i].(*securecookie.SecureCookie)
		if !ok {
			continue
		}
		sc.MaxAge(age)
	}
}

// MaxLength restricts the maximum encoded length of new sessions to l by
// applying it to every *securecookie.SecureCookie codec; other codec types
// are skipped. If l is 0 there is no limit to the size of a session, so use
// it with caution. The default is 4096 (the securecookie default).
func (st *Store) MaxLength(l int) {
	for i := range st.Codecs {
		sc, ok := st.Codecs[i].(*securecookie.SecureCookie)
		if !ok {
			continue
		}
		sc.MaxLength(l)
	}
}

// Cleanup deletes every session row whose expires_at is at or before now.
// A failure is logged through the default slog logger rather than returned.
func (st *Store) Cleanup() {
	cutoff := time.Now()
	query := st.db.NewDelete().Model(&bunSession{}).Where("expires_at <= ?", cutoff)
	if _, err := query.Exec(context.Background()); err != nil {
		logger := slog.Default().With("error", err)
		logger.Error("cleanup")
	}
}

// PeriodicCleanup blocks and calls Cleanup once per interval until quit is
// closed (or receives a value). It is typically run in its own goroutine; the
// ticker is stopped on return.
func (st *Store) PeriodicCleanup(interval time.Duration, quit <-chan struct{}) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-quit:
			return
		case <-ticker.C:
			st.Cleanup()
		}
	}
}