Skip to content

Handle reconnection for multiplexed transport #43

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 5 additions & 15 deletions client/temporal_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ type (
}

clientFactory struct {
transportManger transport.TransportManager
clientTransport transport.ClientTransport
logger log.Logger
}

Expand Down Expand Up @@ -92,24 +92,19 @@ func (c *clientProvider) GetWorkflowServiceClient() (workflowservice.WorkflowSer

// NewFactory creates an instance of client factory that knows how to dispatch RPC calls.
func NewClientFactory(
transportManager transport.TransportManager,
clientTransport transport.ClientTransport,
logger log.Logger,
) ClientFactory {
return &clientFactory{
transportManger: transportManager,
clientTransport: clientTransport,
logger: logger,
}
}

func (cf *clientFactory) NewRemoteAdminClient(
clientConfig config.ProxyClientConfig,
) (adminservice.AdminServiceClient, error) {
clientTransport, err := cf.transportManger.CreateClientTransport(clientConfig)
if err != nil {
return nil, err
}

connection, err := clientTransport.Connect()
connection, err := cf.clientTransport.Connect()
if err != nil {
return nil, err
}
Expand All @@ -120,12 +115,7 @@ func (cf *clientFactory) NewRemoteAdminClient(
func (cf *clientFactory) NewRemoteWorkflowServiceClient(
clientConfig config.ProxyClientConfig,
) (workflowservice.WorkflowServiceClient, error) {
clientTransport, err := cf.transportManger.CreateClientTransport(clientConfig)
if err != nil {
return nil, err
}

connection, err := clientTransport.Connect()
connection, err := cf.clientTransport.Connect()
if err != nil {
return nil, err
}
Expand Down
16 changes: 16 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,22 @@ type (
}
)

func (c ProxyClientConfig) IsMux() bool {
return c.Type == MuxTransport
}

func (c ProxyClientConfig) IsTCP() bool {
return c.Type == TCPTransport
}

func (c ProxyServerConfig) IsMux() bool {
return c.Type == MuxTransport
}

func (c ProxyServerConfig) IsTCP() bool {
return c.Type == TCPTransport
}

func newConfigProvider(ctx *cli.Context) (ConfigProvider, error) {
s2sConfig, err := LoadConfig[S2SProxyConfig](ctx.String(ConfigPathFlag))
if err != nil {
Expand Down
235 changes: 170 additions & 65 deletions proxy/proxy.go
Original file line number Diff line number Diff line change
@@ -1,22 +1,33 @@
package proxy

import (
"fmt"

"github.com/temporalio/s2s-proxy/client"
"github.com/temporalio/s2s-proxy/config"
"github.com/temporalio/s2s-proxy/encryption"
"github.com/temporalio/s2s-proxy/interceptor"
"github.com/temporalio/s2s-proxy/transport"
"go.temporal.io/server/common/log"
"go.temporal.io/server/common/log/tag"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)

type (
ProxyServer struct {
config config.ProxyConfig
opts proxyOptions
transManager transport.TransportManager
logger log.Logger
server *TemporalAPIServer
shutDownCh chan struct{}
}

Proxy struct {
config config.S2SProxyConfig
transportManager transport.TransportManager
outboundServer *TemporalAPIServer
inboundServer *TemporalAPIServer
config config.S2SProxyConfig
outboundServer *ProxyServer
inboundServer *ProxyServer
}

proxyOptions struct {
Expand All @@ -25,51 +36,6 @@ type (
}
)

func NewProxy(
configProvider config.ConfigProvider,
transportManager transport.TransportManager,
logger log.Logger,
) (*Proxy, error) {
s2sConfig := configProvider.GetS2SProxyConfig()
var err error

// Establish underlying connection first before start proxy server.
if err := transportManager.Start(); err != nil {
return nil, err
}

proxy := Proxy{
config: s2sConfig,
transportManager: transportManager,
}

// Proxy consists of two grpc servers: inbound and outbound. The flow looks like the following:
// local server -> proxy(outbound) -> remote server
// local server <- proxy(inbound) <- remote server
//
// Here a remote server can be another proxy as well.
// server-a <-> proxy-a <-> proxy-b <-> server-b
if s2sConfig.Outbound != nil {
if proxy.outboundServer, err = proxy.createServer(*s2sConfig.Outbound, logger, proxyOptions{
IsInbound: false,
Config: s2sConfig,
}); err != nil {
return nil, err
}
}

if s2sConfig.Inbound != nil {
if proxy.inboundServer, err = proxy.createServer(*s2sConfig.Inbound, logger, proxyOptions{
IsInbound: true,
Config: s2sConfig,
}); err != nil {
return nil, err
}
}

return &proxy, nil
}

func makeServerOptions(logger log.Logger, cfg config.ProxyConfig, isInbound bool) ([]grpc.ServerOption, error) {
unaryInterceptors := []grpc.UnaryServerInterceptor{}
streamInterceptors := []grpc.StreamServerInterceptor{}
Expand Down Expand Up @@ -102,38 +68,179 @@ func makeServerOptions(logger log.Logger, cfg config.ProxyConfig, isInbound bool
return opts, nil
}

func (s *Proxy) createServer(cfg config.ProxyConfig, logger log.Logger, opts proxyOptions) (*TemporalAPIServer, error) {
serverOpts, err := makeServerOptions(logger, cfg, opts.IsInbound)
if err != nil {
return nil, err
}
func (ps *ProxyServer) startServer(
serverTransport transport.ServerTransport,
clientTransport transport.ClientTransport,
) error {
cfg := ps.config
opts := ps.opts
logger := ps.logger

serverTransport, err := s.transportManager.CreateServerTransport(cfg.Server)
serverOpts, err := makeServerOptions(logger, cfg, opts.IsInbound)
if err != nil {
return nil, err
return err
}

clientFactory := client.NewClientFactory(s.transportManager, logger)
return NewTemporalAPIServer(
clientFactory := client.NewClientFactory(clientTransport, logger)
ps.server = NewTemporalAPIServer(
cfg.Name,
cfg.Server,
NewAdminServiceProxyServer(cfg.Name, cfg.Client, clientFactory, opts, logger),
NewWorkflowServiceProxyServer(cfg.Name, cfg.Client, clientFactory, logger),
serverOpts,
serverTransport,
logger,
), nil
)

ps.server.Start()
return nil
}

func (ps *ProxyServer) stopServer() {
if ps.server != nil {
ps.server.Stop()
}
}

func (ps *ProxyServer) start() error {
serverConfig := ps.config.Server
clientConfig := ps.config.Client

if serverConfig.IsMux() && clientConfig.IsMux() {
return fmt.Errorf("ProxyServer server and client can't both be multiplexed connection.")
}
var serverTransport transport.ServerTransport
var clientTransport transport.ClientTransport

var openMuxTransport func() (transport.MuxTransport, error)
if serverConfig.IsTCP() {
serverTransport = transport.NewTCPServerTransport(serverConfig.TCPServerSetting)
} else {
openMuxTransport = func() (transport.MuxTransport, error) {
muxTransport, err := ps.transManager.Open(serverConfig.MuxTransportName)
if err != nil {
return nil, err
}

serverTransport = muxTransport
return muxTransport, nil
}
}

if clientConfig.IsTCP() {
clientTransport = transport.NewTCPClientTransport(clientConfig.TCPClientSetting)
} else {
openMuxTransport = func() (transport.MuxTransport, error) {
muxTransport, err := ps.transManager.Open(clientConfig.MuxTransportName)
if err != nil {
return nil, err
}

clientTransport = muxTransport
return muxTransport, nil
}
}

if openMuxTransport == nil {
return ps.startServer(serverTransport, clientTransport)
}

// Manage multiplexed connection via connection manager
go func() {
for {
muxTransport, err := openMuxTransport()
if err != nil {
ps.logger.Error("Failed to open mux transport", tag.Error(err))
return
}

ps.startServer(serverTransport, clientTransport)
select {
case <-muxTransport.CloseChan():
// stop server and try re-open transport
ps.stopServer()
case <-ps.shutDownCh:
ps.stopServer()
muxTransport.Close()
return
}
}
}()

return nil
}

func (ps *ProxyServer) stop() {
close(ps.shutDownCh)
}

func newProxyServer(
cfg config.ProxyConfig,
opts proxyOptions,
transManager transport.TransportManager,
logger log.Logger,
) *ProxyServer {
return &ProxyServer{
config: cfg,
opts: opts,
transManager: transManager,
logger: logger,
shutDownCh: make(chan struct{}),
}
}

func NewProxy(
configProvider config.ConfigProvider,
transManager transport.TransportManager,
logger log.Logger,
) (*Proxy, error) {
s2sConfig := configProvider.GetS2SProxyConfig()
proxy := &Proxy{
config: s2sConfig,
}

// Proxy consists of two grpc servers: inbound and outbound. The flow looks like the following:
// local server -> proxy(outbound) -> remote server
// local server <- proxy(inbound) <- remote server
//
// Here a remote server can be another proxy as well.
// server-a <-> proxy-a <-> proxy-b <-> server-b
if s2sConfig.Outbound != nil {
proxy.outboundServer = newProxyServer(
*s2sConfig.Outbound,
proxyOptions{
IsInbound: false,
Config: s2sConfig,
},
transManager,
logger,
)
}

if s2sConfig.Inbound != nil {
proxy.inboundServer = newProxyServer(
*s2sConfig.Inbound,
proxyOptions{
IsInbound: true,
Config: s2sConfig,
},
transManager,
logger,
)
}

return proxy, nil
}

func (s *Proxy) Start() error {
if s.outboundServer != nil {
if err := s.outboundServer.Start(); err != nil {
if err := s.outboundServer.start(); err != nil {
return err
}
}

if s.inboundServer != nil {
if err := s.inboundServer.Start(); err != nil {
if err := s.inboundServer.start(); err != nil {
return err
}
}
Expand All @@ -143,11 +250,9 @@ func (s *Proxy) Start() error {

func (s *Proxy) Stop() {
if s.inboundServer != nil {
s.inboundServer.Stop()
s.inboundServer.stop()
}
if s.outboundServer != nil {
s.outboundServer.Stop()
s.outboundServer.stop()
}

s.transportManager.Stop()
}
6 changes: 2 additions & 4 deletions proxy/temporal_api_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func NewTemporalAPIServer(
}
}

func (s *TemporalAPIServer) Start() error {
func (s *TemporalAPIServer) Start() {
adminservice.RegisterAdminServiceServer(s.server, s.adminHandler)
workflowservice.RegisterWorkflowServiceServer(s.server, s.workflowserviceHandler)

Expand All @@ -61,12 +61,10 @@ func (s *TemporalAPIServer) Start() error {
// It should not happen if grpc server is based on mux server or normal TCP connection.
s.logger.Warn("grpc server received EOF error")
} else {
s.logger.Fatal("grpc server fatal error ", tag.Error(err))
s.logger.Error("grpc server fatal error ", tag.Error(err))
}
}
}()

return nil
}

func (s *TemporalAPIServer) Stop() {
Expand Down
Loading
Loading