aboutsummaryrefslogtreecommitdiff
path: root/rpc/server.go
diff options
context:
space:
mode:
authorAlexander Kiryukhin <a.kiryukhin@mail.ru>2022-05-22 16:37:48 +0300
committerAlexander Kiryukhin <a.kiryukhin@mail.ru>2022-05-22 16:37:48 +0300
commit4a81eff217c40c459c9a9ed4f318b4dd9bc5ee8a (patch)
tree35e6d3b1f80af80fddedb26d5543377931abfe2f /rpc/server.go
parentc74596c6a6a741e3365a2f372de6e7cdf2583fdc (diff)
Middlewares and optionsv1.1.0
Diffstat (limited to 'rpc/server.go')
-rw-r--r--rpc/server.go86
1 files changed, 52 insertions, 34 deletions
diff --git a/rpc/server.go b/rpc/server.go
index 9c5e847..b2f6158 100644
--- a/rpc/server.go
+++ b/rpc/server.go
@@ -33,20 +33,29 @@ import (
const version = "2.0"
type RpcServer struct {
- Logger Logger
- IgnoreNotifications bool
- handlers map[string]Handler
- transports []transport.Transport
- mu sync.RWMutex
+ logger Logger
+ handlers map[string]Handler
+ middlewares []Middleware
+ transports []transport.Transport
+ mu sync.RWMutex
}
-func New() *RpcServer {
- return &RpcServer{
- Logger: nopLogger{},
- IgnoreNotifications: true,
- handlers: map[string]Handler{},
- transports: []transport.Transport{},
- mu: sync.RWMutex{},
+func New(opts ...Option) *RpcServer {
+ s := &RpcServer{
+ logger: nopLogger{},
+ handlers: map[string]Handler{},
+ transports: []transport.Transport{},
+ mu: sync.RWMutex{},
+ }
+ for _, opt := range opts {
+ opt(s)
+ }
+ return s
+}
+
+func (r *RpcServer) Use(opts ...Option) {
+ for _, opt := range opts {
+ opt(r)
}
}
@@ -56,10 +65,6 @@ func (r *RpcServer) Register(method string, handler Handler) {
r.handlers[method] = handler
}
-func (r *RpcServer) AddTransport(transport transport.Transport) {
- r.transports = append(r.transports, transport)
-}
-
func (r *RpcServer) Run(ctx context.Context) error {
eg, ctx := errgroup.WithContext(ctx)
for _, t := range r.transports {
@@ -70,25 +75,27 @@ func (r *RpcServer) Run(ctx context.Context) error {
return eg.Wait()
}
-func (r *RpcServer) Resolve(ctx context.Context, rd io.Reader, w io.Writer) {
+func (r *RpcServer) Resolve(ctx context.Context, rd io.Reader, w io.Writer, parallel bool) {
dec := json.NewDecoder(rd)
enc := json.NewEncoder(w)
mu := sync.Mutex{}
wg := sync.WaitGroup{}
for {
- req := new(rpcRequest)
+ req := new(RpcRequest)
if err := dec.Decode(req); err != nil {
if err == io.EOF {
break
}
- r.Logger.Logf("Can't read body: %v", err)
+ r.logger.Logf("Can't read body: %v", err)
WriteError(ErrCodeParseError, enc)
break
}
- wg.Add(1)
- go func(req *rpcRequest) {
- defer wg.Done()
- resp := r.callMethod(ctx, req)
+ exec := func() {
+ h := r.callMethod
+ for _, m := range r.middlewares {
+ h = m(h)
+ }
+ resp := h(ctx, req)
if req.Id == nil {
// notification request
return
@@ -96,23 +103,34 @@ func (r *RpcServer) Resolve(ctx context.Context, rd io.Reader, w io.Writer) {
mu.Lock()
defer mu.Unlock()
if err := enc.Encode(resp); err != nil {
- r.Logger.Logf("Can't write response: %v", err)
+ r.logger.Logf("Can't write response: %v", err)
WriteError(ErrCodeInternalError, enc)
}
if w, canFlush := w.(Flusher); canFlush {
w.Flush()
}
- }(req)
+ }
+ if parallel {
+ wg.Add(1)
+ go func(req *RpcRequest) {
+ defer wg.Done()
+ exec()
+ }(req)
+ } else {
+ exec()
+ }
+ }
+ if parallel {
+ wg.Wait()
}
- wg.Wait()
}
-func (r *RpcServer) callMethod(ctx context.Context, req *rpcRequest) *rpcResponse {
+func (r *RpcServer) callMethod(ctx context.Context, req *RpcRequest) *RpcResponse {
r.mu.RLock()
h, ok := r.handlers[req.Method]
r.mu.RUnlock()
if !ok {
- return &rpcResponse{
+ return &RpcResponse{
Jsonrpc: version,
Error: ErrorFromCode(ErrCodeMethodNotFound),
Id: req.Id,
@@ -120,14 +138,14 @@ func (r *RpcServer) callMethod(ctx context.Context, req *rpcRequest) *rpcRespons
}
resp, err := h(ctx, req.Params)
if err != nil {
- r.Logger.Logf("User error %v", err)
- return &rpcResponse{
+ r.logger.Logf("User error %v", err)
+ return &RpcResponse{
Jsonrpc: version,
Error: err,
Id: req.Id,
}
}
- return &rpcResponse{
+ return &RpcResponse{
Jsonrpc: version,
Result: resp,
Id: req.Id,
@@ -135,20 +153,20 @@ func (r *RpcServer) callMethod(ctx context.Context, req *rpcRequest) *rpcRespons
}
func WriteError(code int, enc *json.Encoder) {
- enc.Encode(rpcResponse{
+ enc.Encode(RpcResponse{
Jsonrpc: version,
Error: ErrorFromCode(code),
})
}
-type rpcRequest struct {
+type RpcRequest struct {
Jsonrpc string `json:"jsonrpc"`
Method string `json:"method"`
Params json.RawMessage `json:"params"`
Id any `json:"id"`
}
-type rpcResponse struct {
+type RpcResponse struct {
Jsonrpc string `json:"jsonrpc"`
Result json.RawMessage `json:"result,omitempty"`
Error error `json:"error,omitempty"`