diff options
Diffstat (limited to 'rpc/server.go')
-rw-r--r-- | rpc/server.go | 86 |
1 files changed, 52 insertions, 34 deletions
diff --git a/rpc/server.go b/rpc/server.go index 9c5e847..b2f6158 100644 --- a/rpc/server.go +++ b/rpc/server.go @@ -33,20 +33,29 @@ import ( const version = "2.0" type RpcServer struct { - Logger Logger - IgnoreNotifications bool - handlers map[string]Handler - transports []transport.Transport - mu sync.RWMutex + logger Logger + handlers map[string]Handler + middlewares []Middleware + transports []transport.Transport + mu sync.RWMutex } -func New() *RpcServer { - return &RpcServer{ - Logger: nopLogger{}, - IgnoreNotifications: true, - handlers: map[string]Handler{}, - transports: []transport.Transport{}, - mu: sync.RWMutex{}, +func New(opts ...Option) *RpcServer { + s := &RpcServer{ + logger: nopLogger{}, + handlers: map[string]Handler{}, + transports: []transport.Transport{}, + mu: sync.RWMutex{}, + } + for _, opt := range opts { + opt(s) + } + return s +} + +func (r *RpcServer) Use(opts ...Option) { + for _, opt := range opts { + opt(r) } } @@ -56,10 +65,6 @@ func (r *RpcServer) Register(method string, handler Handler) { r.handlers[method] = handler } -func (r *RpcServer) AddTransport(transport transport.Transport) { - r.transports = append(r.transports, transport) -} - func (r *RpcServer) Run(ctx context.Context) error { eg, ctx := errgroup.WithContext(ctx) for _, t := range r.transports { @@ -70,25 +75,27 @@ func (r *RpcServer) Run(ctx context.Context) error { return eg.Wait() } -func (r *RpcServer) Resolve(ctx context.Context, rd io.Reader, w io.Writer) { +func (r *RpcServer) Resolve(ctx context.Context, rd io.Reader, w io.Writer, parallel bool) { dec := json.NewDecoder(rd) enc := json.NewEncoder(w) mu := sync.Mutex{} wg := sync.WaitGroup{} for { - req := new(rpcRequest) + req := new(RpcRequest) if err := dec.Decode(req); err != nil { if err == io.EOF { break } - r.Logger.Logf("Can't read body: %v", err) + r.logger.Logf("Can't read body: %v", err) WriteError(ErrCodeParseError, enc) break } - wg.Add(1) - go func(req *rpcRequest) { - defer wg.Done() - resp := r.callMethod(ctx, req) + exec := func() { + h := r.callMethod + for _, m := range r.middlewares { + h = m(h) + } + resp := h(ctx, req) if req.Id == nil { // notification request return @@ -96,23 +103,34 @@ func (r *RpcServer) Resolve(ctx context.Context, rd io.Reader, w io.Writer) { mu.Lock() defer mu.Unlock() if err := enc.Encode(resp); err != nil { - r.Logger.Logf("Can't write response: %v", err) + r.logger.Logf("Can't write response: %v", err) WriteError(ErrCodeInternalError, enc) } if w, canFlush := w.(Flusher); canFlush { w.Flush() } - }(req) + } + if parallel { + wg.Add(1) + go func(req *RpcRequest) { + defer wg.Done() + exec() + }(req) + } else { + exec() + } + } + if parallel { + wg.Wait() } - wg.Wait() } -func (r *RpcServer) callMethod(ctx context.Context, req *rpcRequest) *rpcResponse { +func (r *RpcServer) callMethod(ctx context.Context, req *RpcRequest) *RpcResponse { r.mu.RLock() h, ok := r.handlers[req.Method] r.mu.RUnlock() if !ok { - return &rpcResponse{ + return &RpcResponse{ Jsonrpc: version, Error: ErrorFromCode(ErrCodeMethodNotFound), Id: req.Id, @@ -120,14 +138,14 @@ func (r *RpcServer) callMethod(ctx context.Context, req *rpcRequest) *rpcRespons } resp, err := h(ctx, req.Params) if err != nil { - r.Logger.Logf("User error %v", err) - return &rpcResponse{ + r.logger.Logf("User error %v", err) + return &RpcResponse{ Jsonrpc: version, Error: err, Id: req.Id, } } - return &rpcResponse{ + return &RpcResponse{ Jsonrpc: version, Result: resp, Id: req.Id, @@ -135,20 +153,20 @@ func (r *RpcServer) callMethod(ctx context.Context, req *rpcRequest) *rpcRespons } func WriteError(code int, enc *json.Encoder) { - enc.Encode(rpcResponse{ + enc.Encode(RpcResponse{ Jsonrpc: version, Error: ErrorFromCode(code), }) } -type rpcRequest struct { +type RpcRequest struct { Jsonrpc string `json:"jsonrpc"` Method string `json:"method"` Params json.RawMessage `json:"params"` Id any `json:"id"` } -type rpcResponse struct { +type RpcResponse struct { Jsonrpc string `json:"jsonrpc"` Result json.RawMessage `json:"result,omitempty"` Error error `json:"error,omitempty"` |