summaryrefslogtreecommitdiff
path: root/middleware
diff options
context:
space:
mode:
authorAlexander NeonXP Kiryukhin <i@neonxp.ru>2024-07-29 02:38:17 +0300
committerAlexander NeonXP Kiryukhin <i@neonxp.ru>2024-07-29 02:38:17 +0300
commit2916082d5ed94ef86ad58bdb7256ae07b214c4f3 (patch)
tree322a0e9172c07457a892f9737839843b8c584864 /middleware
Начальный коммит
Diffstat (limited to 'middleware')
-rw-r--r--middleware/context.go11
-rw-r--r--middleware/logger.go48
-rw-r--r--middleware/recover.go33
-rw-r--r--middleware/request_id.go35
-rw-r--r--middleware/session.go89
-rw-r--r--middleware/session/bbolt.go71
-rw-r--r--middleware/session/memstore.go31
-rw-r--r--middleware/session/store.go18
8 files changed, 336 insertions, 0 deletions
diff --git a/middleware/context.go b/middleware/context.go
new file mode 100644
index 0000000..b9ad45f
--- /dev/null
+++ b/middleware/context.go
@@ -0,0 +1,11 @@
+package middleware
+
+type ctxKey int
+
+const (
+ requestIDKey ctxKey = iota
+ sessionIDKey
+ sessionValueKey
+ sessionConfigKey
+ sessionStorerKey
+)
diff --git a/middleware/logger.go b/middleware/logger.go
new file mode 100644
index 0000000..80117da
--- /dev/null
+++ b/middleware/logger.go
@@ -0,0 +1,48 @@
+package middleware
+
+import (
+ "log/slog"
+ "net/http"
+ "time"
+
+ "go.neonxp.ru/mux"
+)
+
+type wrappedResponse struct {
+ http.ResponseWriter
+ statusCode int
+}
+
+func (w *wrappedResponse) WriteHeader(code int) {
+ w.statusCode = code
+ w.ResponseWriter.WriteHeader(code)
+}
+
+func Logger(logger *slog.Logger) mux.Middleware {
+ return func(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ requestID := GetRequestID(r)
+ args := []any{
+ slog.String("proto", r.Proto),
+ slog.String("method", r.Method),
+ slog.String("request_uri", r.RequestURI),
+ slog.String("request_id", requestID),
+ }
+ logger.InfoContext(
+ r.Context(),
+ "start request",
+ args...,
+ )
+ t := time.Now()
+ wr := &wrappedResponse{ResponseWriter: w, statusCode: http.StatusOK}
+ next.ServeHTTP(wr, r)
+ args = append(args, slog.String("response_time", time.Since(t).String()))
+ args = append(args, slog.Int("response_status", wr.statusCode))
+ logger.InfoContext(
+ r.Context(),
+ "finish request",
+ args...,
+ )
+ })
+ }
+}
diff --git a/middleware/recover.go b/middleware/recover.go
new file mode 100644
index 0000000..b34d582
--- /dev/null
+++ b/middleware/recover.go
@@ -0,0 +1,33 @@
+package middleware
+
+import (
+ "log/slog"
+ "net/http"
+
+ "go.neonxp.ru/mux"
+)
+
+func Recover(logger *slog.Logger) mux.Middleware {
+ return func(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ defer func() {
+ err := recover()
+ if err == nil {
+ return
+ }
+ requestID := GetRequestID(r)
+ logger.ErrorContext(
+ r.Context(),
+ "panic",
+ slog.Any("panic", err),
+ slog.String("proto", r.Proto),
+ slog.String("method", r.Method),
+ slog.String("request_uri", r.RequestURI),
+ slog.String("request_id", requestID),
+ )
+ }()
+
+ next.ServeHTTP(w, r)
+ })
+ }
+}
diff --git a/middleware/request_id.go b/middleware/request_id.go
new file mode 100644
index 0000000..016b44a
--- /dev/null
+++ b/middleware/request_id.go
@@ -0,0 +1,35 @@
+package middleware
+
+import (
+ "context"
+ "net/http"
+
+ "go.neonxp.ru/objectid"
+)
+
+const RequestIDHeader string = "X-Request-ID"
+
+func RequestID(next http.Handler) http.Handler {
+ objectid.Seed()
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ requestID := r.Header.Get(RequestIDHeader)
+ if requestID == "" {
+ requestID = objectid.New().String()
+ }
+
+ next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), requestIDKey, requestID)))
+ })
+}
+
+func GetRequestID(r *http.Request) string {
+ rid := r.Context().Value(requestIDKey)
+ if rid == nil {
+ return ""
+ }
+ srid, ok := rid.(string)
+ if !ok {
+ return ""
+ }
+
+ return srid
+}
diff --git a/middleware/session.go b/middleware/session.go
new file mode 100644
index 0000000..838e088
--- /dev/null
+++ b/middleware/session.go
@@ -0,0 +1,89 @@
+package middleware
+
+import (
+ "context"
+ "errors"
+ "net/http"
+
+ "go.neonxp.ru/mux"
+ "go.neonxp.ru/mux/middleware/session"
+ "go.neonxp.ru/objectid"
+)
+
+type SessionConfig struct {
+ SessionCookie string
+ Path string
+ Domain string
+ Secure bool
+ HttpOnly bool
+ MaxAge int
+}
+
+var DefaultSessionConfig SessionConfig = SessionConfig{
+ SessionCookie: "_session",
+ Path: "/",
+ Domain: "",
+ Secure: false,
+ HttpOnly: true,
+ MaxAge: 30 * 3600,
+}
+
+func Session(config SessionConfig, storer session.Store) mux.Middleware {
+ return func(h http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ var (
+ sessionID string
+ values session.Value
+ )
+ cookie, err := r.Cookie(config.SessionCookie)
+ switch {
+ case err == nil:
+ sessionID = cookie.Value
+ values = storer.Load(r.Context(), sessionID)
+ case errors.Is(err, http.ErrNoCookie):
+ sessionID = objectid.New().String()
+ values = session.Value{}
+ }
+
+ http.SetCookie(w, &http.Cookie{
+ Name: config.SessionCookie,
+ Value: sessionID,
+ Path: config.Path,
+ Domain: config.Domain,
+ Secure: config.Secure,
+ HttpOnly: config.HttpOnly,
+ MaxAge: config.MaxAge,
+ })
+
+ ctx := context.WithValue(r.Context(), sessionValueKey, &values)
+ ctx = context.WithValue(ctx, sessionIDKey, sessionID)
+ ctx = context.WithValue(ctx, sessionConfigKey, config)
+ ctx = context.WithValue(ctx, sessionStorerKey, storer)
+
+ h.ServeHTTP(w, r.WithContext(ctx))
+
+ storer.Save(r.Context(), sessionID, values)
+
+ })
+ }
+}
+
+func SessionFromRequest(r *http.Request) *session.Value {
+ return r.Context().Value(sessionValueKey).(*session.Value)
+}
+
+func ClearSession(w http.ResponseWriter, r *http.Request) {
+ storer := r.Context().Value(sessionStorerKey).(session.Store)
+ sessionID := r.Context().Value(sessionIDKey).(string)
+ storer.Remove(r.Context(), sessionID)
+ config := r.Context().Value(sessionConfigKey).(SessionConfig)
+ http.SetCookie(w, &http.Cookie{
+ Name: config.SessionCookie,
+ Value: sessionID,
+ Path: config.Path,
+ Domain: config.Domain,
+ Secure: config.Secure,
+ HttpOnly: config.HttpOnly,
+ MaxAge: -1,
+ })
+}
diff --git a/middleware/session/bbolt.go b/middleware/session/bbolt.go
new file mode 100644
index 0000000..1068ed8
--- /dev/null
+++ b/middleware/session/bbolt.go
@@ -0,0 +1,71 @@
+package session
+
+import (
+ "bytes"
+ "context"
+ "encoding/gob"
+ "log/slog"
+
+ "go.etcd.io/bbolt"
+)
+
+func New(db *bbolt.DB, bucketName []byte) Store {
+ return &BoltStore{
+ db: db,
+ bucketName: bucketName,
+ }
+}
+
+type BoltStore struct {
+ db *bbolt.DB
+ bucketName []byte
+}
+
+func (s *BoltStore) Load(ctx context.Context, sessionID string) Value {
+ v := Value{}
+ err := s.db.View(func(tx *bbolt.Tx) error {
+ bucket := tx.Bucket(s.bucketName)
+ if bucket == nil {
+ // no bucket -- normal situation
+ return nil
+ }
+ vb := bucket.Get([]byte(sessionID))
+ if vb == nil {
+ // no session -- no error
+ return nil
+ }
+ rdr := bytes.NewBuffer(vb)
+
+ return gob.NewDecoder(rdr).Decode(&v)
+ })
+ if err != nil {
+ slog.WarnContext(ctx, "failed load session", slog.Any("error", err))
+ }
+ return v
+}
+
+func (s *BoltStore) Save(ctx context.Context, sessionID string, value Value) error {
+ return s.db.Update(func(tx *bbolt.Tx) error {
+ bucket, err := tx.CreateBucketIfNotExists(s.bucketName)
+ if err != nil {
+ return err
+ }
+ wrt := bytes.NewBuffer([]byte{})
+ if err := gob.NewEncoder(wrt).Encode(value); err != nil {
+ return err
+ }
+
+ return bucket.Put([]byte(sessionID), wrt.Bytes())
+ })
+}
+
+func (s *BoltStore) Remove(ctx context.Context, sessionID string) error {
+ return s.db.Update(func(tx *bbolt.Tx) error {
+ bucket, err := tx.CreateBucketIfNotExists(s.bucketName)
+ if err != nil {
+ return err
+ }
+
+ return bucket.Delete([]byte(sessionID))
+ })
+}
diff --git a/middleware/session/memstore.go b/middleware/session/memstore.go
new file mode 100644
index 0000000..2fcef39
--- /dev/null
+++ b/middleware/session/memstore.go
@@ -0,0 +1,31 @@
+package session
+
+import (
+ "context"
+ "sync"
+)
+
+type MemoryStore struct {
+ store sync.Map
+}
+
+func (s *MemoryStore) Load(ctx context.Context, sessionID string) Value {
+ val, ok := s.store.Load(sessionID)
+ if ok {
+ return val.(Value)
+ }
+
+ return Value{}
+}
+
+func (s *MemoryStore) Save(ctx context.Context, sessionID string, value Value) error {
+ s.store.Store(sessionID, value)
+
+ return nil
+}
+
+func (s *MemoryStore) Remove(ctx context.Context, sessionID string) error {
+ s.store.Delete(sessionID)
+
+ return nil
+}
diff --git a/middleware/session/store.go b/middleware/session/store.go
new file mode 100644
index 0000000..b74a8aa
--- /dev/null
+++ b/middleware/session/store.go
@@ -0,0 +1,18 @@
+package session
+
+import (
+ "context"
+ "errors"
+)
+
+var (
+ ErrSessionNotFound = errors.New("session not found")
+)
+
+type Store interface {
+ Load(ctx context.Context, sessionID string) Value
+ Save(ctx context.Context, sessionID string, value Value) error
+ Remove(ctx context.Context, sessionID string) error
+}
+
+type Value map[string]any