diff --git a/cpc/cpc.go b/cpc/cpc.go
deleted file mode 100644
index dbe9bc29..00000000
--- a/cpc/cpc.go
+++ /dev/null
@@ -1,159 +0,0 @@
-// Copyright (c) 2023 Proton AG
-//
-// This file is part of Proton Mail Bridge.
-//
-// Proton Mail Bridge is free software: you can redistribute it and/or modify
-// it under the terms of the GNU General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// Proton Mail Bridge is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU General Public License for more details.
-//
-// You should have received a copy of the GNU General Public License
-// along with Proton Mail Bridge. If not, see .
-
-package cpc
-
-import (
- "context"
- "errors"
- "fmt"
-)
-
-var ErrRequestHasNoReply = errors.New("request has no reply channel")
-var ErrExpectedReply = errors.New("request does not have reply channel")
-
-// Utilities to implement Chanel Procedure Calls. Similar in concept to RPC, but with between go-routines.
-
-type RequestReply struct {
- Value any
- Error error
-}
-
-type Request struct {
- Value any
- Reply chan RequestReply
-}
-
-func NewRequest(value any) *Request {
- return &Request{
- Value: value,
- Reply: make(chan RequestReply),
- }
-}
-
-func NewRequestWithoutReply(value any) *Request {
- return &Request{
- Value: value,
- Reply: nil,
- }
-}
-
-func (r *Request) Close() {
- if r.Reply != nil {
- panic("request reply has not been sent")
- }
-}
-
-func (r *Request) SendReply(ctx context.Context, value any, err error) {
- if r.Reply == nil {
- panic("request has no reply")
- }
-
- defer func() {
- close(r.Reply)
- r.Reply = nil
- }()
-
- select {
- case <-ctx.Done():
- case r.Reply <- RequestReply{
- Value: value,
- Error: err,
- }:
- }
-}
-
-type CPC struct {
- request chan *Request
-}
-
-func NewCPC() *CPC {
- return &CPC{
- request: make(chan *Request),
- }
-}
-
-// Receive is meant to be called by the code that is supposed to handle the requests that arrive.
-func (c *CPC) Receive(ctx context.Context, f func(context.Context, *Request)) {
- for request := range c.request {
- f(ctx, request)
- request.Close()
- }
-}
-
-func (c *CPC) ReceiveCh() <-chan *Request {
- return c.request
-}
-
-func (c *CPC) Close() {
- close(c.request)
-}
-
-// SendNoReply sends a request which doesn't expect a reply.
-func (c *CPC) SendNoReply(ctx context.Context, value any) error {
- return c.executeNoReplyImpl(ctx, NewRequestWithoutReply(value))
-}
-
-// SendWithReply sends a request which expects a reply.
-func (c *CPC) SendWithReply(ctx context.Context, value any) (any, error) {
- return c.executeReplyImpl(ctx, NewRequest(value))
-}
-
-func SendWithReplyType[T any](ctx context.Context, c *CPC, value any) (T, error) {
- val, err := c.executeReplyImpl(ctx, NewRequest(value))
- if err != nil {
- var t T
- return t, err
- }
-
- switch vt := val.(type) {
- case T:
- return vt, nil
- default:
- var t T
- return t, fmt.Errorf("reply type does not match")
- }
-}
-
-func (c *CPC) executeNoReplyImpl(ctx context.Context, request *Request) error {
- select {
- case <-ctx.Done():
- return ctx.Err()
- case c.request <- request:
- }
-
- return nil
-}
-
-func (c *CPC) executeReplyImpl(ctx context.Context, request *Request) (any, error) {
- if request.Reply == nil {
- return nil, ErrExpectedReply
- }
-
- select {
- case <-ctx.Done():
- return nil, ctx.Err()
- case c.request <- request:
- }
-
- select {
- case <-ctx.Done():
- return nil, ctx.Err()
- case reply := <-request.Reply:
- return reply.Value, reply.Error
- }
-}
diff --git a/internal/bridge/server_manager.go b/internal/bridge/server_manager.go
index acca55c8..9e06a430 100644
--- a/internal/bridge/server_manager.go
+++ b/internal/bridge/server_manager.go
@@ -26,10 +26,10 @@ import (
"github.com/ProtonMail/gluon"
"github.com/ProtonMail/gluon/connector"
"github.com/ProtonMail/gluon/logging"
- "github.com/ProtonMail/proton-bridge/v3/cpc"
"github.com/ProtonMail/proton-bridge/v3/internal/events"
"github.com/ProtonMail/proton-bridge/v3/internal/safe"
"github.com/ProtonMail/proton-bridge/v3/internal/user"
+ "github.com/ProtonMail/proton-bridge/v3/pkg/cpc"
"github.com/emersion/go-smtp"
"github.com/sirupsen/logrus"
)
@@ -77,31 +77,31 @@ func (sm *ServerManager) Init(bridge *Bridge) error {
func (sm *ServerManager) CloseServers(ctx context.Context) error {
defer sm.requests.Close()
- _, err := sm.requests.SendWithReply(ctx, &smRequestClose{})
+ _, err := sm.requests.Send(ctx, &smRequestClose{})
return err
}
func (sm *ServerManager) RestartIMAP(ctx context.Context) error {
- _, err := sm.requests.SendWithReply(ctx, &smRequestRestartIMAP{})
+ _, err := sm.requests.Send(ctx, &smRequestRestartIMAP{})
return err
}
func (sm *ServerManager) RestartSMTP(ctx context.Context) error {
- _, err := sm.requests.SendWithReply(ctx, &smRequestRestartSMTP{})
+ _, err := sm.requests.Send(ctx, &smRequestRestartSMTP{})
return err
}
func (sm *ServerManager) AddIMAPUser(ctx context.Context, user *user.User) error {
- _, err := sm.requests.SendWithReply(ctx, &smRequestAddIMAPUser{user: user})
+ _, err := sm.requests.Send(ctx, &smRequestAddIMAPUser{user: user})
return err
}
func (sm *ServerManager) RemoveIMAPUser(ctx context.Context, user *user.User, withData bool) error {
- _, err := sm.requests.SendWithReply(ctx, &smRequestRemoveIMAPUser{
+ _, err := sm.requests.Send(ctx, &smRequestRemoveIMAPUser{
user: user,
withData: withData,
})
@@ -110,7 +110,7 @@ func (sm *ServerManager) RemoveIMAPUser(ctx context.Context, user *user.User, wi
}
func (sm *ServerManager) SetGluonDir(ctx context.Context, gluonDir string) error {
- _, err := sm.requests.SendWithReply(ctx, &smRequestSetGluonDir{
+ _, err := sm.requests.Send(ctx, &smRequestSetGluonDir{
dir: gluonDir,
})
@@ -118,7 +118,7 @@ func (sm *ServerManager) SetGluonDir(ctx context.Context, gluonDir string) error
}
func (sm *ServerManager) AddGluonUser(ctx context.Context, conn connector.Connector, passphrase []byte) (string, error) {
- reply, err := cpc.SendWithReplyType[string](ctx, sm.requests, &smRequestAddGluonUser{
+ reply, err := cpc.SendTyped[string](ctx, sm.requests, &smRequestAddGluonUser{
conn: conn,
passphrase: passphrase,
})
@@ -127,7 +127,7 @@ func (sm *ServerManager) AddGluonUser(ctx context.Context, conn connector.Connec
}
func (sm *ServerManager) RemoveGluonUser(ctx context.Context, gluonID string) error {
- _, err := sm.requests.SendWithReply(ctx, &smRequestRemoveGluonUser{
+ _, err := sm.requests.Send(ctx, &smRequestRemoveGluonUser{
userID: gluonID,
})
@@ -165,23 +165,23 @@ func (sm *ServerManager) run(ctx context.Context, bridge *Bridge) {
return
}
- switch r := request.Value.(type) {
+ switch r := request.Value().(type) {
case *smRequestClose:
sm.handleClose(ctx, bridge)
- request.SendReply(ctx, nil, nil)
+ request.Reply(ctx, nil, nil)
return
case *smRequestRestartSMTP:
err := sm.restartSMTP(bridge)
- request.SendReply(ctx, nil, err)
+ request.Reply(ctx, nil, err)
case *smRequestRestartIMAP:
err := sm.restartIMAP(ctx, bridge)
- request.SendReply(ctx, nil, err)
+ request.Reply(ctx, nil, err)
case *smRequestAddIMAPUser:
err := sm.handleAddIMAPUser(ctx, r.user)
- request.SendReply(ctx, nil, err)
+ request.Reply(ctx, nil, err)
if err == nil {
sm.loadedUserCount++
sm.handleLoadedUserCountChange(ctx, bridge)
@@ -189,7 +189,7 @@ func (sm *ServerManager) run(ctx context.Context, bridge *Bridge) {
case *smRequestRemoveIMAPUser:
err := sm.handleRemoveIMAPUser(ctx, r.user, r.withData)
- request.SendReply(ctx, nil, err)
+ request.Reply(ctx, nil, err)
if err == nil {
sm.loadedUserCount--
sm.handleLoadedUserCountChange(ctx, bridge)
@@ -197,15 +197,15 @@ func (sm *ServerManager) run(ctx context.Context, bridge *Bridge) {
case *smRequestSetGluonDir:
err := sm.handleSetGluonDir(ctx, bridge, r.dir)
- request.SendReply(ctx, nil, err)
+ request.Reply(ctx, nil, err)
case *smRequestAddGluonUser:
id, err := sm.handleAddGluonUser(ctx, r.conn, r.passphrase)
- request.SendReply(ctx, id, err)
+ request.Reply(ctx, id, err)
case *smRequestRemoveGluonUser:
err := sm.handleRemoveGluonUser(ctx, r.userID)
- request.SendReply(ctx, nil, err)
+ request.Reply(ctx, nil, err)
}
}
}
diff --git a/pkg/cpc/cpc.go b/pkg/cpc/cpc.go
new file mode 100644
index 00000000..0892c28b
--- /dev/null
+++ b/pkg/cpc/cpc.go
@@ -0,0 +1,129 @@
+// Copyright (c) 2023 Proton AG
+//
+// This file is part of Proton Mail Bridge.
+//
+// Proton Mail Bridge is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// Proton Mail Bridge is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with Proton Mail Bridge. If not, see .
+
+package cpc
+
+import (
+ "context"
+ "errors"
+)
+
+var ErrInvalidReplyType = errors.New("reply type does not match")
+
+// Utilities to implement Chanel Procedure Calls. Similar in concept to RPC, but with between go-routines.
+
+// Request contains the data for a request as well as the means to reply to a request.
+type Request struct {
+ value any
+ reply chan reply
+}
+
+// Value returns the request value.
+func (r *Request) Value() any {
+ return r.value
+}
+
+// Reply should be used to send a reply to a given request.
+func (r *Request) Reply(ctx context.Context, value any, err error) {
+ defer close(r.reply)
+
+ select {
+ case <-ctx.Done():
+ case r.reply <- reply{
+ value: value,
+ error: err,
+ }:
+ }
+}
+
+// CPC Channel Procedure Call. A play on RPC, but with channels. Use this type to send requests and wait for replies
+// from a goroutine.
+type CPC struct {
+ request chan *Request
+}
+
+func NewCPC() *CPC {
+ return &CPC{
+ request: make(chan *Request),
+ }
+}
+
+// Receive invokes the function on all the request that arrive.
+func (c *CPC) Receive(ctx context.Context, f func(context.Context, *Request)) {
+ for request := range c.request {
+ f(ctx, request)
+ }
+}
+
+// ReceiveCh returns the channel on which all requests are sent.
+func (c *CPC) ReceiveCh() <-chan *Request {
+ return c.request
+}
+
+// Close closes the CPC channel and no further requests should be made.
+func (c *CPC) Close() {
+ close(c.request)
+}
+
+// Send sends a request which expects a reply.
+func (c *CPC) Send(ctx context.Context, value any) (any, error) {
+ return c.execute(ctx, newRequest(value))
+}
+
+// SendTyped is similar to CPC.Send, but ensure that reply is of the given Type T.
+func SendTyped[T any](ctx context.Context, c *CPC, value any) (T, error) {
+ val, err := c.execute(ctx, newRequest(value))
+ if err != nil {
+ var t T
+ return t, err
+ }
+
+ switch vt := val.(type) {
+ case T:
+ return vt, nil
+ default:
+ var t T
+ return t, ErrInvalidReplyType
+ }
+}
+
+type reply struct {
+ value any
+ error error
+}
+
+func (c *CPC) execute(ctx context.Context, request *Request) (any, error) {
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ case c.request <- request:
+ }
+
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ case r := <-request.reply:
+ return r.value, r.error
+ }
+}
+
+func newRequest(value any) *Request {
+ return &Request{
+ value: value,
+ reply: make(chan reply),
+ }
+}
diff --git a/cpc/cpc_test.go b/pkg/cpc/cpc_test.go
similarity index 83%
rename from cpc/cpc_test.go
rename to pkg/cpc/cpc_test.go
index 8fc70406..102970e5 100644
--- a/cpc/cpc_test.go
+++ b/pkg/cpc/cpc_test.go
@@ -42,22 +42,24 @@ func TestCPC_Receive(t *testing.T) {
wg.Add(1)
cpc.Receive(context.Background(), func(ctx context.Context, request *Request) {
- switch request.Value.(type) {
+ switch request.Value().(type) {
case sendIntRequest:
- request.SendReply(ctx, replyValue, nil)
+ request.Reply(ctx, replyValue, nil)
case quitRequest:
- cpc.Close()
+ request.Reply(ctx, nil, nil)
default:
panic("unknown request")
}
})
}()
- r, err := cpc.SendWithReply(context.Background(), sendIntRequest{})
+ r, err := cpc.Send(context.Background(), sendIntRequest{})
require.NoError(t, err)
require.Equal(t, r, replyValue)
- require.NoError(t, cpc.SendNoReply(context.Background(), quitRequest{}))
+ _, err = cpc.Send(context.Background(), quitRequest{})
+ require.NoError(t, err)
+ cpc.Close()
wg.Wait()
}