summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--go.mod5
-rw-r--r--go.sum2
-rw-r--r--middleware/logger.go29
-rw-r--r--middleware/recover.go34
-rw-r--r--middleware/request_id.go40
-rw-r--r--middleware/use.go13
-rw-r--r--wrap.go93
-rw-r--r--wrap_test.go45
8 files changed, 261 insertions, 0 deletions
diff --git a/go.mod b/go.mod
new file mode 100644
index 0000000..d4c707e
--- /dev/null
+++ b/go.mod
@@ -0,0 +1,5 @@
+module go.neonxp.ru/muxtool
+
+go 1.22.3
+
+require go.neonxp.ru/objectid v0.0.2
diff --git a/go.sum b/go.sum
new file mode 100644
index 0000000..3ce48c1
--- /dev/null
+++ b/go.sum
@@ -0,0 +1,2 @@
+go.neonxp.ru/objectid v0.0.2 h1:Z/G6zvBxmUq0NTq681oGH8pTbBWwi6VA22YOYludIPs=
+go.neonxp.ru/objectid v0.0.2/go.mod h1:s0dRi//oe1liiKcor1KmWx09WzkD6Wtww8ZaIv+VLBs=
diff --git a/middleware/logger.go b/middleware/logger.go
new file mode 100644
index 0000000..039bd19
--- /dev/null
+++ b/middleware/logger.go
@@ -0,0 +1,29 @@
+package middleware
+
+import (
+ "net/http"
+
+ "log/slog"
+)
+
+func Logger(logger *slog.Logger) Middleware {
+ return func(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ next.ServeHTTP(w, r)
+ requestID := GetRequestID(r)
+ args := []any{
+ slog.String("proto", r.Proto),
+ slog.String("method", r.Method),
+ slog.String("request_uri", r.RequestURI),
+ }
+ if requestID != "" {
+ args = append(args, slog.String("request_id", requestID))
+ }
+ logger.InfoContext(
+ r.Context(),
+ "request",
+ args...,
+ )
+ })
+ }
+}
diff --git a/middleware/recover.go b/middleware/recover.go
new file mode 100644
index 0000000..6b5f2cb
--- /dev/null
+++ b/middleware/recover.go
@@ -0,0 +1,34 @@
+package middleware
+
+import (
+ "net/http"
+ "runtime/debug"
+
+ "log/slog"
+)
+
+func Recover(logger *slog.Logger) Middleware {
+ return func(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ defer func() {
+ err := recover()
+ if err == nil {
+ return
+ }
+ debug.PrintStack()
+ requestID := GetRequestID(r)
+ logger.ErrorContext(
+ r.Context(),
+ "panic",
+ slog.Any("panic", err),
+ slog.String("proto", r.Proto),
+ slog.String("method", r.Method),
+ slog.String("request_uri", r.RequestURI),
+ slog.String("request_id", requestID),
+ )
+ }()
+
+ next.ServeHTTP(w, r)
+ })
+ }
+}
diff --git a/middleware/request_id.go b/middleware/request_id.go
new file mode 100644
index 0000000..0e9a521
--- /dev/null
+++ b/middleware/request_id.go
@@ -0,0 +1,40 @@
+package middleware
+
+import (
+ "context"
+ "net/http"
+
+ "go.neonxp.ru/objectid"
+)
+
+type ctxKeyRequestID int
+
+const (
+ RequestIDKey ctxKeyRequestID = 0
+ RequestIDHeader string = "X-Request-ID"
+)
+
+func RequestID(next http.Handler) http.Handler {
+ objectid.Seed()
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ requestID := r.Header.Get(RequestIDHeader)
+ if requestID == "" {
+ requestID = objectid.New().String()
+ }
+
+ next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), RequestIDKey, requestID)))
+ })
+}
+
+func GetRequestID(r *http.Request) string {
+ rid := r.Context().Value(RequestIDKey)
+ if rid == nil {
+ return ""
+ }
+ srid, ok := rid.(string)
+ if !ok {
+ return ""
+ }
+
+ return srid
+}
diff --git a/middleware/use.go b/middleware/use.go
new file mode 100644
index 0000000..6610e2f
--- /dev/null
+++ b/middleware/use.go
@@ -0,0 +1,13 @@
+package middleware
+
+import "net/http"
+
+type Middleware func(http.Handler) http.Handler
+
+func Use(handler http.Handler, middlewares ...Middleware) http.Handler {
+ for _, h := range middlewares {
+ handler = h(handler)
+ }
+
+ return handler
+}
diff --git a/wrap.go b/wrap.go
new file mode 100644
index 0000000..e8ae72b
--- /dev/null
+++ b/wrap.go
@@ -0,0 +1,93 @@
+package muxtool
+
+import (
+ "context"
+ "encoding/json"
+ "net/http"
+)
+
+// Wrap API handler and returns standard http.HandlerFunc function
+func Wrap[RQ any, RS any](handler func(ctx context.Context, request *RQ) (RS, error)) http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ req := new(RQ)
+ richifyRequest(req, r)
+ switch r.Method {
+ case http.MethodPost, http.MethodPatch, http.MethodDelete, http.MethodPut:
+ if err := json.NewDecoder(r.Body).Decode(req); err != nil {
+ w.WriteHeader(http.StatusBadRequest)
+ _, _ = w.Write([]byte(err.Error()))
+ return
+ }
+ }
+ resp, err := handler(r.Context(), req)
+ if err != nil {
+ w.WriteHeader(http.StatusInternalServerError)
+ _, _ = w.Write([]byte(err.Error()))
+ return
+ }
+
+ statusCode := http.StatusOK
+ contentType := "application/json"
+ var body []byte
+
+ if v, ok := (any)(resp).(WithContentType); ok {
+ contentType = v.ContentType()
+ }
+ if v, ok := (any)(resp).(WithHTTPStatus); ok {
+ statusCode = v.Status()
+ }
+ if v, ok := (any)(resp).(Renderer); ok {
+ body, err = v.Render()
+ } else {
+ body, err = json.Marshal(resp)
+ }
+ if err != nil {
+ w.WriteHeader(http.StatusInternalServerError)
+ _, _ = w.Write([]byte(err.Error()))
+ return
+ }
+ w.WriteHeader(statusCode)
+ w.Header().Set("Content-Type", contentType)
+ w.Write(body)
+ }
+}
+
+func richifyRequest[RQ any](req *RQ, baseRequest *http.Request) {
+ if v, ok := (any)(req).(WithHeader); ok {
+ v.WithHeader(baseRequest.Header)
+ }
+ if v, ok := (any)(req).(WithMethod); ok {
+ v.WithMethod(baseRequest.Method)
+ }
+}
+
+type NilRequest struct{}
+
+// Optional interfaces for request type
+
+// WithHeader sets headers to request
+type WithHeader interface {
+ WithHeader(header http.Header)
+}
+
+// WithMethod sets method to request
+type WithMethod interface {
+ WithMethod(method string)
+}
+
+// Optional interfaces for response type
+
+// Renderer renders response to byte slice
+type Renderer interface {
+ Render() ([]byte, error)
+}
+
+// WithContentType returns custom content type for response
+type WithContentType interface {
+ ContentType() string
+}
+
+// WithHTTPStatus returns custom status code
+type WithHTTPStatus interface {
+ Status() int
+}
diff --git a/wrap_test.go b/wrap_test.go
new file mode 100644
index 0000000..31b831e
--- /dev/null
+++ b/wrap_test.go
@@ -0,0 +1,45 @@
+package muxtool
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "net/http/httptest"
+)
+
+func ExampleWrap() {
+ rr := httptest.NewRecorder()
+
+ // Sample request
+ req := reqHello{
+ Name: "NeonXP",
+ }
+ b, _ := json.Marshal(req)
+ request, _ := http.NewRequest(http.MethodPost, "/hello", bytes.NewReader(b))
+
+ // Handler
+ mux := http.NewServeMux()
+ // Handle wrapped `handleHello(context.Context, *reqHello) (*respHello, error)`
+ mux.Handle("POST /hello", Wrap(handleHello))
+
+ mux.ServeHTTP(rr, request)
+
+ fmt.Println(rr.Body.String())
+ // Output: {"message":"Hello, NeonXP!"}
+}
+
+type reqHello struct {
+ Name string `json:"name"`
+}
+
+type respHello struct {
+ Message string `json:"message"`
+}
+
+func handleHello(ctx context.Context, req *reqHello) (*respHello, error) {
+ return &respHello{
+ Message: fmt.Sprintf("Hello, %s!", req.Name),
+ }, nil
+}