aboutsummaryrefslogtreecommitdiff
path: root/rpc/middleware/validation.go
diff options
context:
space:
mode:
Diffstat (limited to 'rpc/middleware/validation.go')
-rw-r--r--rpc/middleware/validation.go30
1 files changed, 21 insertions, 9 deletions
diff --git a/rpc/middleware/validation.go b/rpc/middleware/validation.go
index e994383..80b4cfd 100644
--- a/rpc/middleware/validation.go
+++ b/rpc/middleware/validation.go
@@ -30,25 +30,37 @@ import (
"go.neonxp.dev/jsonrpc2/rpc"
)
+type ServiceSchema map[string]MethodSchema
+
+func MustSchema(schema string) ServiceSchema {
+ ss := new(ServiceSchema)
+ if err := json.Unmarshal([]byte(schema), ss); err != nil {
+ panic(err)
+ }
+ return *ss
+}
+
type MethodSchema struct {
- Request jsonschema.Schema
- Response jsonschema.Schema
+ Request *jsonschema.Schema `json:"request"`
+ Response *jsonschema.Schema `json:"response"`
}
-func Validation(serviceSchema map[string]MethodSchema) (rpc.Middleware, error) {
+func Validation(serviceSchema ServiceSchema) (rpc.Middleware, error) {
return func(handler rpc.RpcHandler) rpc.RpcHandler {
return func(ctx context.Context, req *rpc.RpcRequest) *rpc.RpcResponse {
- if rs, ok := serviceSchema[strings.ToLower(req.Method)]; ok {
- if errResp := formatError(ctx, req.Id, rs.Request, req.Params); errResp != nil {
+ rs, hasSchema := serviceSchema[strings.ToLower(req.Method)]
+ if hasSchema && rs.Request != nil {
+ if errResp := formatError(ctx, req.Id, *rs.Request, req.Params); errResp != nil {
return errResp
}
- resp := handler(ctx, req)
- if errResp := formatError(ctx, req.Id, rs.Response, resp.Result); errResp != nil {
+ }
+ resp := handler(ctx, req)
+ if hasSchema && rs.Response != nil {
+ if errResp := formatError(ctx, req.Id, *rs.Response, resp.Result); errResp != nil {
return errResp
}
- return resp
}
- return handler(ctx, req)
+ return resp
}
}, nil
}