diff --git a/cpc/cpc.go b/cpc/cpc.go deleted file mode 100644 index dbe9bc29..00000000 --- a/cpc/cpc.go +++ /dev/null @@ -1,159 +0,0 @@ -// Copyright (c) 2023 Proton AG -// -// This file is part of Proton Mail Bridge. -// -// Proton Mail Bridge is free software: you can redistribute it and/or modify -// it under the terms of the GNU General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// Proton Mail Bridge is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU General Public License for more details. -// -// You should have received a copy of the GNU General Public License -// along with Proton Mail Bridge. If not, see . - -package cpc - -import ( - "context" - "errors" - "fmt" -) - -var ErrRequestHasNoReply = errors.New("request has no reply channel") -var ErrExpectedReply = errors.New("request does not have reply channel") - -// Utilities to implement Chanel Procedure Calls. Similar in concept to RPC, but with between go-routines. - -type RequestReply struct { - Value any - Error error -} - -type Request struct { - Value any - Reply chan RequestReply -} - -func NewRequest(value any) *Request { - return &Request{ - Value: value, - Reply: make(chan RequestReply), - } -} - -func NewRequestWithoutReply(value any) *Request { - return &Request{ - Value: value, - Reply: nil, - } -} - -func (r *Request) Close() { - if r.Reply != nil { - panic("request reply has not been sent") - } -} - -func (r *Request) SendReply(ctx context.Context, value any, err error) { - if r.Reply == nil { - panic("request has no reply") - } - - defer func() { - close(r.Reply) - r.Reply = nil - }() - - select { - case <-ctx.Done(): - case r.Reply <- RequestReply{ - Value: value, - Error: err, - }: - } -} - -type CPC struct { - request chan *Request -} - -func NewCPC() *CPC { - return &CPC{ - request: make(chan *Request), - } -} - -// Receive is meant to be called by the code that is supposed to handle the requests that arrive. -func (c *CPC) Receive(ctx context.Context, f func(context.Context, *Request)) { - for request := range c.request { - f(ctx, request) - request.Close() - } -} - -func (c *CPC) ReceiveCh() <-chan *Request { - return c.request -} - -func (c *CPC) Close() { - close(c.request) -} - -// SendNoReply sends a request which doesn't expect a reply. -func (c *CPC) SendNoReply(ctx context.Context, value any) error { - return c.executeNoReplyImpl(ctx, NewRequestWithoutReply(value)) -} - -// SendWithReply sends a request which expects a reply. -func (c *CPC) SendWithReply(ctx context.Context, value any) (any, error) { - return c.executeReplyImpl(ctx, NewRequest(value)) -} - -func SendWithReplyType[T any](ctx context.Context, c *CPC, value any) (T, error) { - val, err := c.executeReplyImpl(ctx, NewRequest(value)) - if err != nil { - var t T - return t, err - } - - switch vt := val.(type) { - case T: - return vt, nil - default: - var t T - return t, fmt.Errorf("reply type does not match") - } -} - -func (c *CPC) executeNoReplyImpl(ctx context.Context, request *Request) error { - select { - case <-ctx.Done(): - return ctx.Err() - case c.request <- request: - } - - return nil -} - -func (c *CPC) executeReplyImpl(ctx context.Context, request *Request) (any, error) { - if request.Reply == nil { - return nil, ErrExpectedReply - } - - select { - case <-ctx.Done(): - return nil, ctx.Err() - case c.request <- request: - } - - select { - case <-ctx.Done(): - return nil, ctx.Err() - case reply := <-request.Reply: - return reply.Value, reply.Error - } -} diff --git a/internal/bridge/server_manager.go b/internal/bridge/server_manager.go index acca55c8..9e06a430 100644 --- a/internal/bridge/server_manager.go +++ b/internal/bridge/server_manager.go @@ -26,10 +26,10 @@ import ( "github.com/ProtonMail/gluon" "github.com/ProtonMail/gluon/connector" "github.com/ProtonMail/gluon/logging" - "github.com/ProtonMail/proton-bridge/v3/cpc" "github.com/ProtonMail/proton-bridge/v3/internal/events" "github.com/ProtonMail/proton-bridge/v3/internal/safe" "github.com/ProtonMail/proton-bridge/v3/internal/user" + "github.com/ProtonMail/proton-bridge/v3/pkg/cpc" "github.com/emersion/go-smtp" "github.com/sirupsen/logrus" ) @@ -77,31 +77,31 @@ func (sm *ServerManager) Init(bridge *Bridge) error { func (sm *ServerManager) CloseServers(ctx context.Context) error { defer sm.requests.Close() - _, err := sm.requests.SendWithReply(ctx, &smRequestClose{}) + _, err := sm.requests.Send(ctx, &smRequestClose{}) return err } func (sm *ServerManager) RestartIMAP(ctx context.Context) error { - _, err := sm.requests.SendWithReply(ctx, &smRequestRestartIMAP{}) + _, err := sm.requests.Send(ctx, &smRequestRestartIMAP{}) return err } func (sm *ServerManager) RestartSMTP(ctx context.Context) error { - _, err := sm.requests.SendWithReply(ctx, &smRequestRestartSMTP{}) + _, err := sm.requests.Send(ctx, &smRequestRestartSMTP{}) return err } func (sm *ServerManager) AddIMAPUser(ctx context.Context, user *user.User) error { - _, err := sm.requests.SendWithReply(ctx, &smRequestAddIMAPUser{user: user}) + _, err := sm.requests.Send(ctx, &smRequestAddIMAPUser{user: user}) return err } func (sm *ServerManager) RemoveIMAPUser(ctx context.Context, user *user.User, withData bool) error { - _, err := sm.requests.SendWithReply(ctx, &smRequestRemoveIMAPUser{ + _, err := sm.requests.Send(ctx, &smRequestRemoveIMAPUser{ user: user, withData: withData, }) @@ -110,7 +110,7 @@ func (sm *ServerManager) RemoveIMAPUser(ctx context.Context, user *user.User, wi } func (sm *ServerManager) SetGluonDir(ctx context.Context, gluonDir string) error { - _, err := sm.requests.SendWithReply(ctx, &smRequestSetGluonDir{ + _, err := sm.requests.Send(ctx, &smRequestSetGluonDir{ dir: gluonDir, }) @@ -118,7 +118,7 @@ func (sm *ServerManager) SetGluonDir(ctx context.Context, gluonDir string) error } func (sm *ServerManager) AddGluonUser(ctx context.Context, conn connector.Connector, passphrase []byte) (string, error) { - reply, err := cpc.SendWithReplyType[string](ctx, sm.requests, &smRequestAddGluonUser{ + reply, err := cpc.SendTyped[string](ctx, sm.requests, &smRequestAddGluonUser{ conn: conn, passphrase: passphrase, }) @@ -127,7 +127,7 @@ func (sm *ServerManager) AddGluonUser(ctx context.Context, conn connector.Connec } func (sm *ServerManager) RemoveGluonUser(ctx context.Context, gluonID string) error { - _, err := sm.requests.SendWithReply(ctx, &smRequestRemoveGluonUser{ + _, err := sm.requests.Send(ctx, &smRequestRemoveGluonUser{ userID: gluonID, }) @@ -165,23 +165,23 @@ func (sm *ServerManager) run(ctx context.Context, bridge *Bridge) { return } - switch r := request.Value.(type) { + switch r := request.Value().(type) { case *smRequestClose: sm.handleClose(ctx, bridge) - request.SendReply(ctx, nil, nil) + request.Reply(ctx, nil, nil) return case *smRequestRestartSMTP: err := sm.restartSMTP(bridge) - request.SendReply(ctx, nil, err) + request.Reply(ctx, nil, err) case *smRequestRestartIMAP: err := sm.restartIMAP(ctx, bridge) - request.SendReply(ctx, nil, err) + request.Reply(ctx, nil, err) case *smRequestAddIMAPUser: err := sm.handleAddIMAPUser(ctx, r.user) - request.SendReply(ctx, nil, err) + request.Reply(ctx, nil, err) if err == nil { sm.loadedUserCount++ sm.handleLoadedUserCountChange(ctx, bridge) @@ -189,7 +189,7 @@ func (sm *ServerManager) run(ctx context.Context, bridge *Bridge) { case *smRequestRemoveIMAPUser: err := sm.handleRemoveIMAPUser(ctx, r.user, r.withData) - request.SendReply(ctx, nil, err) + request.Reply(ctx, nil, err) if err == nil { sm.loadedUserCount-- sm.handleLoadedUserCountChange(ctx, bridge) @@ -197,15 +197,15 @@ func (sm *ServerManager) run(ctx context.Context, bridge *Bridge) { case *smRequestSetGluonDir: err := sm.handleSetGluonDir(ctx, bridge, r.dir) - request.SendReply(ctx, nil, err) + request.Reply(ctx, nil, err) case *smRequestAddGluonUser: id, err := sm.handleAddGluonUser(ctx, r.conn, r.passphrase) - request.SendReply(ctx, id, err) + request.Reply(ctx, id, err) case *smRequestRemoveGluonUser: err := sm.handleRemoveGluonUser(ctx, r.userID) - request.SendReply(ctx, nil, err) + request.Reply(ctx, nil, err) } } } diff --git a/pkg/cpc/cpc.go b/pkg/cpc/cpc.go new file mode 100644 index 00000000..0892c28b --- /dev/null +++ b/pkg/cpc/cpc.go @@ -0,0 +1,129 @@ +// Copyright (c) 2023 Proton AG +// +// This file is part of Proton Mail Bridge. +// +// Proton Mail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Proton Mail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Proton Mail Bridge. If not, see . + +package cpc + +import ( + "context" + "errors" +) + +var ErrInvalidReplyType = errors.New("reply type does not match") + +// Utilities to implement Chanel Procedure Calls. Similar in concept to RPC, but with between go-routines. + +// Request contains the data for a request as well as the means to reply to a request. +type Request struct { + value any + reply chan reply +} + +// Value returns the request value. +func (r *Request) Value() any { + return r.value +} + +// Reply should be used to send a reply to a given request. +func (r *Request) Reply(ctx context.Context, value any, err error) { + defer close(r.reply) + + select { + case <-ctx.Done(): + case r.reply <- reply{ + value: value, + error: err, + }: + } +} + +// CPC Channel Procedure Call. A play on RPC, but with channels. Use this type to send requests and wait for replies +// from a goroutine. +type CPC struct { + request chan *Request +} + +func NewCPC() *CPC { + return &CPC{ + request: make(chan *Request), + } +} + +// Receive invokes the function on all the request that arrive. +func (c *CPC) Receive(ctx context.Context, f func(context.Context, *Request)) { + for request := range c.request { + f(ctx, request) + } +} + +// ReceiveCh returns the channel on which all requests are sent. +func (c *CPC) ReceiveCh() <-chan *Request { + return c.request +} + +// Close closes the CPC channel and no further requests should be made. +func (c *CPC) Close() { + close(c.request) +} + +// Send sends a request which expects a reply. +func (c *CPC) Send(ctx context.Context, value any) (any, error) { + return c.execute(ctx, newRequest(value)) +} + +// SendTyped is similar to CPC.Send, but ensure that reply is of the given Type T. +func SendTyped[T any](ctx context.Context, c *CPC, value any) (T, error) { + val, err := c.execute(ctx, newRequest(value)) + if err != nil { + var t T + return t, err + } + + switch vt := val.(type) { + case T: + return vt, nil + default: + var t T + return t, ErrInvalidReplyType + } +} + +type reply struct { + value any + error error +} + +func (c *CPC) execute(ctx context.Context, request *Request) (any, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case c.request <- request: + } + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case r := <-request.reply: + return r.value, r.error + } +} + +func newRequest(value any) *Request { + return &Request{ + value: value, + reply: make(chan reply), + } +} diff --git a/cpc/cpc_test.go b/pkg/cpc/cpc_test.go similarity index 83% rename from cpc/cpc_test.go rename to pkg/cpc/cpc_test.go index 8fc70406..102970e5 100644 --- a/cpc/cpc_test.go +++ b/pkg/cpc/cpc_test.go @@ -42,22 +42,24 @@ func TestCPC_Receive(t *testing.T) { wg.Add(1) cpc.Receive(context.Background(), func(ctx context.Context, request *Request) { - switch request.Value.(type) { + switch request.Value().(type) { case sendIntRequest: - request.SendReply(ctx, replyValue, nil) + request.Reply(ctx, replyValue, nil) case quitRequest: - cpc.Close() + request.Reply(ctx, nil, nil) default: panic("unknown request") } }) }() - r, err := cpc.SendWithReply(context.Background(), sendIntRequest{}) + r, err := cpc.Send(context.Background(), sendIntRequest{}) require.NoError(t, err) require.Equal(t, r, replyValue) - require.NoError(t, cpc.SendNoReply(context.Background(), quitRequest{})) + _, err = cpc.Send(context.Background(), quitRequest{}) + require.NoError(t, err) + cpc.Close() wg.Wait() }