diff options
Diffstat (limited to 'rpc/server.go')
-rw-r--r-- | rpc/server.go | 89 |
1 files changed, 52 insertions, 37 deletions
diff --git a/rpc/server.go b/rpc/server.go index 4fa004d..1bb15d5 100644 --- a/rpc/server.go +++ b/rpc/server.go @@ -24,6 +24,10 @@ import ( "encoding/json" "io" "sync" + + "golang.org/x/sync/errgroup" + + "go.neonxp.dev/jsonrpc2/transport" ) const version = "2.0" @@ -32,6 +36,7 @@ type RpcServer struct { Logger Logger IgnoreNotifications bool handlers map[string]Handler + transports []transport.Transport mu sync.RWMutex } @@ -40,6 +45,7 @@ func New() *RpcServer { Logger: nopLogger{}, IgnoreNotifications: true, handlers: map[string]Handler{}, + transports: []transport.Transport{}, mu: sync.RWMutex{}, } } @@ -50,51 +56,55 @@ func (r *RpcServer) Register(method string, handler Handler) { r.handlers[method] = handler } -func (r *RpcServer) SingleRequest(ctx context.Context, reader io.Reader, writer io.Writer) { - req := new(rpcRequest) - if err := json.NewDecoder(reader).Decode(req); err != nil { - r.Logger.Logf("Can't read body: %v", err) - WriteError(ErrCodeParseError, writer) - return - } - resp := r.callMethod(ctx, req) - if req.Id == nil && r.IgnoreNotifications { - // notification request - return - } - if err := json.NewEncoder(writer).Encode(resp); err != nil { - r.Logger.Logf("Can't write response: %v", err) - WriteError(ErrCodeInternalError, writer) - return - } +func (r *RpcServer) AddTransport(transport transport.Transport) { + r.transports = append(r.transports, transport) } -func (r *RpcServer) BatchRequest(ctx context.Context, reader io.Reader, writer io.Writer) { - var req []rpcRequest - if err := json.NewDecoder(reader).Decode(&req); err != nil { - r.Logger.Logf("Can't read body: %v", err) - WriteError(ErrCodeParseError, writer) - return +func (r *RpcServer) Run(ctx context.Context) error { + eg, ctx := errgroup.WithContext(ctx) + for _, t := range r.transports { + eg.Go(func(t transport.Transport) func() error { + return func() error { return t.Run(ctx, r) } + }(t)) } - var responses []*rpcResponse + return eg.Wait() +} + +func (r *RpcServer) Resolve(ctx context.Context, rd io.Reader, w io.Writer) { + dec := json.NewDecoder(rd) + enc := json.NewEncoder(w) + mu := sync.Mutex{} wg := sync.WaitGroup{} - wg.Add(len(req)) - for _, j := range req { - go func(req rpcRequest) { + for { + req := new(rpcRequest) + if err := dec.Decode(req); err != nil { + if err == io.EOF { + break + } + r.Logger.Logf("Can't read body: %v", err) + WriteError(ErrCodeParseError, enc) + break + } + wg.Add(1) + go func(req *rpcRequest) { defer wg.Done() - resp := r.callMethod(ctx, &req) - if req.Id == nil && r.IgnoreNotifications { + resp := r.callMethod(ctx, req) + if req.Id == nil { // notification request return } - responses = append(responses, resp) - }(j) + mu.Lock() + defer mu.Unlock() + if err := enc.Encode(resp); err != nil { + r.Logger.Logf("Can't write response: %v", err) + WriteError(ErrCodeInternalError, enc) + } + if w, canFlush := w.(Flusher); canFlush { + w.Flush() + } + }(req) } wg.Wait() - if err := json.NewEncoder(writer).Encode(responses); err != nil { - r.Logger.Logf("Can't write response: %v", err) - WriteError(ErrCodeInternalError, writer) - } } func (r *RpcServer) callMethod(ctx context.Context, req *rpcRequest) *rpcResponse { @@ -124,8 +134,8 @@ func (r *RpcServer) callMethod(ctx context.Context, req *rpcRequest) *rpcRespons } } -func WriteError(code int, w io.Writer) { - _ = json.NewEncoder(w).Encode(rpcResponse{ +func WriteError(code int, enc *json.Encoder) { + enc.Encode(rpcResponse{ Jsonrpc: version, Error: NewError(code), }) @@ -144,3 +154,8 @@ type rpcResponse struct { Error error `json:"error,omitempty"` Id any `json:"id,omitempty"` } + +type Flusher interface { + // Flush sends any buffered data to the client. + Flush() +} |