diff options
Diffstat (limited to 'rpc/middleware/validation.go')
-rw-r--r-- | rpc/middleware/validation.go | 30 |
1 files changed, 21 insertions, 9 deletions
diff --git a/rpc/middleware/validation.go b/rpc/middleware/validation.go index e994383..80b4cfd 100644 --- a/rpc/middleware/validation.go +++ b/rpc/middleware/validation.go @@ -30,25 +30,37 @@ import ( "go.neonxp.dev/jsonrpc2/rpc" ) +type ServiceSchema map[string]MethodSchema + +func MustSchema(schema string) ServiceSchema { + ss := new(ServiceSchema) + if err := json.Unmarshal([]byte(schema), ss); err != nil { + panic(err) + } + return *ss +} + type MethodSchema struct { - Request jsonschema.Schema - Response jsonschema.Schema + Request *jsonschema.Schema `json:"request"` + Response *jsonschema.Schema `json:"response"` } -func Validation(serviceSchema map[string]MethodSchema) (rpc.Middleware, error) { +func Validation(serviceSchema ServiceSchema) (rpc.Middleware, error) { return func(handler rpc.RpcHandler) rpc.RpcHandler { return func(ctx context.Context, req *rpc.RpcRequest) *rpc.RpcResponse { - if rs, ok := serviceSchema[strings.ToLower(req.Method)]; ok { - if errResp := formatError(ctx, req.Id, rs.Request, req.Params); errResp != nil { + rs, hasSchema := serviceSchema[strings.ToLower(req.Method)] + if hasSchema && rs.Request != nil { + if errResp := formatError(ctx, req.Id, *rs.Request, req.Params); errResp != nil { return errResp } - resp := handler(ctx, req) - if errResp := formatError(ctx, req.Id, rs.Response, resp.Result); errResp != nil { + } + resp := handler(ctx, req) + if hasSchema && rs.Response != nil { + if errResp := formatError(ctx, req.Id, *rs.Response, resp.Result); errResp != nil { return errResp } - return resp } - return handler(ctx, req) + return resp } }, nil } |