diff options
Diffstat (limited to 'middleware')
-rw-r--r-- | middleware/context.go | 4 | ||||
-rw-r--r-- | middleware/session/bbolt.go | 11 | ||||
-rw-r--r-- | middleware/session/context.go | 9 | ||||
-rw-r--r-- | middleware/session/memstore.go | 11 | ||||
-rw-r--r-- | middleware/session/session.go | 114 | ||||
-rw-r--r-- | middleware/session/store.go | 17 |
6 files changed, 97 insertions, 69 deletions
diff --git a/middleware/context.go b/middleware/context.go index a76bca0..8671140 100644 --- a/middleware/context.go +++ b/middleware/context.go @@ -4,8 +4,4 @@ type ctxKey int const ( requestIDKey ctxKey = iota - SessionIDKey - SessionValueKey - SessionConfigKey - SessionStorerKey ) diff --git a/middleware/session/bbolt.go b/middleware/session/bbolt.go index be484d7..bf46953 100644 --- a/middleware/session/bbolt.go +++ b/middleware/session/bbolt.go @@ -2,7 +2,6 @@ package session import ( "bytes" - "context" "encoding/gob" "log/slog" @@ -21,8 +20,8 @@ type BoltStore struct { bucketName []byte } -func (s *BoltStore) Load(ctx context.Context, sessionID string) Value { - v := Value{} +func (s *BoltStore) Load(sessionID string) Values { + v := Values{} err := s.db.View(func(tx *bbolt.Tx) error { bucket := tx.Bucket(s.bucketName) if bucket == nil { @@ -39,12 +38,12 @@ func (s *BoltStore) Load(ctx context.Context, sessionID string) Value { return gob.NewDecoder(rdr).Decode(&v) }) if err != nil { - slog.WarnContext(ctx, "failed load session", slog.Any("error", err)) + slog.Warn("failed load session", slog.Any("error", err)) } return v } -func (s *BoltStore) Save(ctx context.Context, sessionID string, value Value) error { +func (s *BoltStore) Save(sessionID string, value Values) error { return s.db.Update(func(tx *bbolt.Tx) error { bucket, err := tx.CreateBucketIfNotExists(s.bucketName) if err != nil { @@ -59,7 +58,7 @@ func (s *BoltStore) Save(ctx context.Context, sessionID string, value Value) err }) } -func (s *BoltStore) Remove(ctx context.Context, sessionID string) error { +func (s *BoltStore) Remove(sessionID string) error { return s.db.Update(func(tx *bbolt.Tx) error { bucket, err := tx.CreateBucketIfNotExists(s.bucketName) if err != nil { diff --git a/middleware/session/context.go b/middleware/session/context.go new file mode 100644 index 0000000..870957d --- /dev/null +++ b/middleware/session/context.go @@ -0,0 +1,9 @@ +package session + +type ctxKey int + +const ( + sessionManagerKey ctxKey = iota + sessionIDKey + sessionValueKey +) diff --git a/middleware/session/memstore.go b/middleware/session/memstore.go index 2fcef39..d8cb958 100644 --- a/middleware/session/memstore.go +++ b/middleware/session/memstore.go @@ -1,7 +1,6 @@ package session import ( - "context" "sync" ) @@ -9,22 +8,22 @@ type MemoryStore struct { store sync.Map } -func (s *MemoryStore) Load(ctx context.Context, sessionID string) Value { +func (s *MemoryStore) Load(sessionID string) Values { val, ok := s.store.Load(sessionID) if ok { - return val.(Value) + return val.(Values) } - return Value{} + return Values{} } -func (s *MemoryStore) Save(ctx context.Context, sessionID string, value Value) error { +func (s *MemoryStore) Save(sessionID string, value Values) error { s.store.Store(sessionID, value) return nil } -func (s *MemoryStore) Remove(ctx context.Context, sessionID string) error { +func (s *MemoryStore) Remove(sessionID string) error { s.store.Delete(sessionID) return nil diff --git a/middleware/session/session.go b/middleware/session/session.go index 4d02bf1..47fc0fb 100644 --- a/middleware/session/session.go +++ b/middleware/session/session.go @@ -4,10 +4,9 @@ import ( "context" "errors" "net/http" - "sync" + "time" "go.neonxp.ru/mux" - "go.neonxp.ru/mux/middleware" "go.neonxp.ru/objectid" ) @@ -17,7 +16,7 @@ type Config struct { Domain string Secure bool HttpOnly bool - MaxAge int + MaxAge time.Duration } var DefaultConfig Config = Config{ @@ -26,69 +25,104 @@ var DefaultConfig Config = Config{ Domain: "", Secure: false, HttpOnly: true, - MaxAge: 30 * 3600, + MaxAge: 365 * 24 * time.Hour, } -func Middleware(config Config, storer Store) mux.Middleware { - if storer == nil { - storer = &MemoryStore{store: sync.Map{}} +var ( + ErrSessionNotFound = errors.New("session not found") + ErrNoSessionInContext = errors.New("no session in context") +) + +type SessionManager struct { + config *Config + storer Store +} + +func New(storer Store) *SessionManager { + return NewWithConfig(&DefaultConfig, storer) +} + +func NewWithConfig(config *Config, storer Store) *SessionManager { + return &SessionManager{ + config: config, + storer: storer, } +} +func (s *SessionManager) Middleware() mux.Middleware { return func(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var ( sessionID string - values Value + values Values ) - cookie, err := r.Cookie(config.SessionCookie) + cookie, err := r.Cookie(s.config.SessionCookie) switch { case err == nil: sessionID = cookie.Value - values = storer.Load(r.Context(), sessionID) + values = s.storer.Load(sessionID) case errors.Is(err, http.ErrNoCookie): sessionID = objectid.New().String() - values = Value{} } - http.SetCookie(w, &http.Cookie{ - Name: config.SessionCookie, - Value: sessionID, - Path: config.Path, - Domain: config.Domain, - Secure: config.Secure, - HttpOnly: config.HttpOnly, - MaxAge: config.MaxAge, - }) - - ctx := context.WithValue(r.Context(), middleware.SessionValueKey, &values) - ctx = context.WithValue(ctx, middleware.SessionIDKey, sessionID) - ctx = context.WithValue(ctx, middleware.SessionConfigKey, config) - ctx = context.WithValue(ctx, middleware.SessionStorerKey, storer) + ctx := context.WithValue(r.Context(), sessionManagerKey, s) + ctx = context.WithValue(ctx, sessionIDKey, sessionID) + ctx = context.WithValue(ctx, sessionValueKey, values) h.ServeHTTP(w, r.WithContext(ctx)) - - storer.Save(r.Context(), sessionID, values) - }) } } -func FromRequest(r *http.Request) *Value { - return r.Context().Value(middleware.SessionValueKey).(*Value) +func (s *SessionManager) Values(ctx context.Context) Values { + aValue := ctx.Value(sessionValueKey) + values, ok := aValue.(Values) + if !ok || values == nil { + values = Values{} + } + + return values } -func Clear(w http.ResponseWriter, r *http.Request) { - storer := r.Context().Value(middleware.SessionStorerKey).(Store) - sessionID := r.Context().Value(middleware.SessionIDKey).(string) - storer.Remove(r.Context(), sessionID) - config := r.Context().Value(middleware.SessionConfigKey).(Config) +func (s *SessionManager) Save(w http.ResponseWriter, r *http.Request, values Values) error { + aSessionID := r.Context().Value(sessionIDKey) + sessionID, ok := aSessionID.(string) + if !ok { + return ErrNoSessionInContext + } + http.SetCookie(w, &http.Cookie{ - Name: config.SessionCookie, + Name: s.config.SessionCookie, Value: sessionID, - Path: config.Path, - Domain: config.Domain, - Secure: config.Secure, - HttpOnly: config.HttpOnly, + Path: s.config.Path, + Domain: s.config.Domain, + Secure: s.config.Secure, + HttpOnly: s.config.HttpOnly, + MaxAge: int(s.config.MaxAge.Seconds()), + }) + + return s.storer.Save(sessionID, values) +} +func (s *SessionManager) Clear(w http.ResponseWriter, r *http.Request) error { + aSessionID := r.Context().Value(sessionIDKey) + sessionID, ok := aSessionID.(string) + if !ok { + return ErrNoSessionInContext + } + + http.SetCookie(w, &http.Cookie{ + Name: s.config.SessionCookie, + Value: sessionID, + Path: s.config.Path, + Domain: s.config.Domain, + Secure: s.config.Secure, + HttpOnly: s.config.HttpOnly, MaxAge: -1, }) + + return s.storer.Remove(sessionID) +} + +func FromRequest(r *http.Request) *SessionManager { + return r.Context().Value(sessionManagerKey).(*SessionManager) } diff --git a/middleware/session/store.go b/middleware/session/store.go index b74a8aa..a02ba1e 100644 --- a/middleware/session/store.go +++ b/middleware/session/store.go @@ -1,18 +1,9 @@ package session -import ( - "context" - "errors" -) - -var ( - ErrSessionNotFound = errors.New("session not found") -) - type Store interface { - Load(ctx context.Context, sessionID string) Value - Save(ctx context.Context, sessionID string, value Value) error - Remove(ctx context.Context, sessionID string) error + Load(sessionID string) Values + Save(sessionID string, value Values) error + Remove(sessionID string) error } -type Value map[string]any +type Values map[string]any |