summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--middleware/context.go4
-rw-r--r--middleware/session/bbolt.go11
-rw-r--r--middleware/session/context.go9
-rw-r--r--middleware/session/memstore.go11
-rw-r--r--middleware/session/session.go114
-rw-r--r--middleware/session/store.go17
6 files changed, 97 insertions, 69 deletions
diff --git a/middleware/context.go b/middleware/context.go
index a76bca0..8671140 100644
--- a/middleware/context.go
+++ b/middleware/context.go
@@ -4,8 +4,4 @@ type ctxKey int
const (
requestIDKey ctxKey = iota
- SessionIDKey
- SessionValueKey
- SessionConfigKey
- SessionStorerKey
)
diff --git a/middleware/session/bbolt.go b/middleware/session/bbolt.go
index be484d7..bf46953 100644
--- a/middleware/session/bbolt.go
+++ b/middleware/session/bbolt.go
@@ -2,7 +2,6 @@ package session
import (
"bytes"
- "context"
"encoding/gob"
"log/slog"
@@ -21,8 +20,8 @@ type BoltStore struct {
bucketName []byte
}
-func (s *BoltStore) Load(ctx context.Context, sessionID string) Value {
- v := Value{}
+func (s *BoltStore) Load(sessionID string) Values {
+ v := Values{}
err := s.db.View(func(tx *bbolt.Tx) error {
bucket := tx.Bucket(s.bucketName)
if bucket == nil {
@@ -39,12 +38,12 @@ func (s *BoltStore) Load(ctx context.Context, sessionID string) Value {
return gob.NewDecoder(rdr).Decode(&v)
})
if err != nil {
- slog.WarnContext(ctx, "failed load session", slog.Any("error", err))
+ slog.Warn("failed load session", slog.Any("error", err))
}
return v
}
-func (s *BoltStore) Save(ctx context.Context, sessionID string, value Value) error {
+func (s *BoltStore) Save(sessionID string, value Values) error {
return s.db.Update(func(tx *bbolt.Tx) error {
bucket, err := tx.CreateBucketIfNotExists(s.bucketName)
if err != nil {
@@ -59,7 +58,7 @@ func (s *BoltStore) Save(ctx context.Context, sessionID string, value Value) err
})
}
-func (s *BoltStore) Remove(ctx context.Context, sessionID string) error {
+func (s *BoltStore) Remove(sessionID string) error {
return s.db.Update(func(tx *bbolt.Tx) error {
bucket, err := tx.CreateBucketIfNotExists(s.bucketName)
if err != nil {
diff --git a/middleware/session/context.go b/middleware/session/context.go
new file mode 100644
index 0000000..870957d
--- /dev/null
+++ b/middleware/session/context.go
@@ -0,0 +1,9 @@
+package session
+
+type ctxKey int
+
+const (
+ sessionManagerKey ctxKey = iota
+ sessionIDKey
+ sessionValueKey
+)
diff --git a/middleware/session/memstore.go b/middleware/session/memstore.go
index 2fcef39..d8cb958 100644
--- a/middleware/session/memstore.go
+++ b/middleware/session/memstore.go
@@ -1,7 +1,6 @@
package session
import (
- "context"
"sync"
)
@@ -9,22 +8,22 @@ type MemoryStore struct {
store sync.Map
}
-func (s *MemoryStore) Load(ctx context.Context, sessionID string) Value {
+func (s *MemoryStore) Load(sessionID string) Values {
val, ok := s.store.Load(sessionID)
if ok {
- return val.(Value)
+ return val.(Values)
}
- return Value{}
+ return Values{}
}
-func (s *MemoryStore) Save(ctx context.Context, sessionID string, value Value) error {
+func (s *MemoryStore) Save(sessionID string, value Values) error {
s.store.Store(sessionID, value)
return nil
}
-func (s *MemoryStore) Remove(ctx context.Context, sessionID string) error {
+func (s *MemoryStore) Remove(sessionID string) error {
s.store.Delete(sessionID)
return nil
diff --git a/middleware/session/session.go b/middleware/session/session.go
index 4d02bf1..47fc0fb 100644
--- a/middleware/session/session.go
+++ b/middleware/session/session.go
@@ -4,10 +4,9 @@ import (
"context"
"errors"
"net/http"
- "sync"
+ "time"
"go.neonxp.ru/mux"
- "go.neonxp.ru/mux/middleware"
"go.neonxp.ru/objectid"
)
@@ -17,7 +16,7 @@ type Config struct {
Domain string
Secure bool
HttpOnly bool
- MaxAge int
+ MaxAge time.Duration
}
var DefaultConfig Config = Config{
@@ -26,69 +25,104 @@ var DefaultConfig Config = Config{
Domain: "",
Secure: false,
HttpOnly: true,
- MaxAge: 30 * 3600,
+ MaxAge: 365 * 24 * time.Hour,
}
-func Middleware(config Config, storer Store) mux.Middleware {
- if storer == nil {
- storer = &MemoryStore{store: sync.Map{}}
+var (
+ ErrSessionNotFound = errors.New("session not found")
+ ErrNoSessionInContext = errors.New("no session in context")
+)
+
+type SessionManager struct {
+ config *Config
+ storer Store
+}
+
+func New(storer Store) *SessionManager {
+ return NewWithConfig(&DefaultConfig, storer)
+}
+
+func NewWithConfig(config *Config, storer Store) *SessionManager {
+ return &SessionManager{
+ config: config,
+ storer: storer,
}
+}
+func (s *SessionManager) Middleware() mux.Middleware {
return func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var (
sessionID string
- values Value
+ values Values
)
- cookie, err := r.Cookie(config.SessionCookie)
+ cookie, err := r.Cookie(s.config.SessionCookie)
switch {
case err == nil:
sessionID = cookie.Value
- values = storer.Load(r.Context(), sessionID)
+ values = s.storer.Load(sessionID)
case errors.Is(err, http.ErrNoCookie):
sessionID = objectid.New().String()
- values = Value{}
}
- http.SetCookie(w, &http.Cookie{
- Name: config.SessionCookie,
- Value: sessionID,
- Path: config.Path,
- Domain: config.Domain,
- Secure: config.Secure,
- HttpOnly: config.HttpOnly,
- MaxAge: config.MaxAge,
- })
-
- ctx := context.WithValue(r.Context(), middleware.SessionValueKey, &values)
- ctx = context.WithValue(ctx, middleware.SessionIDKey, sessionID)
- ctx = context.WithValue(ctx, middleware.SessionConfigKey, config)
- ctx = context.WithValue(ctx, middleware.SessionStorerKey, storer)
+ ctx := context.WithValue(r.Context(), sessionManagerKey, s)
+ ctx = context.WithValue(ctx, sessionIDKey, sessionID)
+ ctx = context.WithValue(ctx, sessionValueKey, values)
h.ServeHTTP(w, r.WithContext(ctx))
-
- storer.Save(r.Context(), sessionID, values)
-
})
}
}
-func FromRequest(r *http.Request) *Value {
- return r.Context().Value(middleware.SessionValueKey).(*Value)
+func (s *SessionManager) Values(ctx context.Context) Values {
+ aValue := ctx.Value(sessionValueKey)
+ values, ok := aValue.(Values)
+ if !ok || values == nil {
+ values = Values{}
+ }
+
+ return values
}
-func Clear(w http.ResponseWriter, r *http.Request) {
- storer := r.Context().Value(middleware.SessionStorerKey).(Store)
- sessionID := r.Context().Value(middleware.SessionIDKey).(string)
- storer.Remove(r.Context(), sessionID)
- config := r.Context().Value(middleware.SessionConfigKey).(Config)
+func (s *SessionManager) Save(w http.ResponseWriter, r *http.Request, values Values) error {
+ aSessionID := r.Context().Value(sessionIDKey)
+ sessionID, ok := aSessionID.(string)
+ if !ok {
+ return ErrNoSessionInContext
+ }
+
http.SetCookie(w, &http.Cookie{
- Name: config.SessionCookie,
+ Name: s.config.SessionCookie,
Value: sessionID,
- Path: config.Path,
- Domain: config.Domain,
- Secure: config.Secure,
- HttpOnly: config.HttpOnly,
+ Path: s.config.Path,
+ Domain: s.config.Domain,
+ Secure: s.config.Secure,
+ HttpOnly: s.config.HttpOnly,
+ MaxAge: int(s.config.MaxAge.Seconds()),
+ })
+
+ return s.storer.Save(sessionID, values)
+}
+func (s *SessionManager) Clear(w http.ResponseWriter, r *http.Request) error {
+ aSessionID := r.Context().Value(sessionIDKey)
+ sessionID, ok := aSessionID.(string)
+ if !ok {
+ return ErrNoSessionInContext
+ }
+
+ http.SetCookie(w, &http.Cookie{
+ Name: s.config.SessionCookie,
+ Value: sessionID,
+ Path: s.config.Path,
+ Domain: s.config.Domain,
+ Secure: s.config.Secure,
+ HttpOnly: s.config.HttpOnly,
MaxAge: -1,
})
+
+ return s.storer.Remove(sessionID)
+}
+
+func FromRequest(r *http.Request) *SessionManager {
+ return r.Context().Value(sessionManagerKey).(*SessionManager)
}
diff --git a/middleware/session/store.go b/middleware/session/store.go
index b74a8aa..a02ba1e 100644
--- a/middleware/session/store.go
+++ b/middleware/session/store.go
@@ -1,18 +1,9 @@
package session
-import (
- "context"
- "errors"
-)
-
-var (
- ErrSessionNotFound = errors.New("session not found")
-)
-
type Store interface {
- Load(ctx context.Context, sessionID string) Value
- Save(ctx context.Context, sessionID string, value Value) error
- Remove(ctx context.Context, sessionID string) error
+ Load(sessionID string) Values
+ Save(sessionID string, value Values) error
+ Remove(sessionID string) error
}
-type Value map[string]any
+type Values map[string]any