Skip to content

Commit f9f07cf

Browse files
support tls to non-tls and non-tls to tls websocket proxy
1 parent e7eb8b8 commit f9f07cf

File tree

2 files changed

+159
-75
lines changed

2 files changed

+159
-75
lines changed

middleware/proxy.go

Lines changed: 18 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -132,40 +132,31 @@ var DefaultProxyConfig = ProxyConfig{
132132
}
133133

134134
func proxyRaw(t *ProxyTarget, c echo.Context, config ProxyConfig) http.Handler {
135+
var dialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
136+
if transport, ok := config.Transport.(*http.Transport); ok {
137+
if transport.TLSClientConfig != nil {
138+
d := tls.Dialer{
139+
Config: transport.TLSClientConfig,
140+
}
141+
dialFunc = d.DialContext
142+
}
143+
}
144+
if dialFunc == nil {
145+
var d net.Dialer
146+
dialFunc = d.DialContext
147+
}
148+
135149
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
136150
in, _, err := c.Response().Hijack()
137151
if err != nil {
138152
c.Set("_error", fmt.Errorf("proxy raw, hijack error=%w, url=%s", err, t.URL))
139153
return
140154
}
141155
defer in.Close()
142-
143-
var out net.Conn
144-
if c.IsTLS() {
145-
transport, ok := config.Transport.(*http.Transport)
146-
if !ok {
147-
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, "proxy raw, invalid transport type"))
148-
return
149-
}
150-
151-
if transport.TLSClientConfig == nil {
152-
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, "proxy raw, TLSClientConfig is not set"))
153-
return
154-
}
155-
156-
out, err = tls.Dial("tcp", t.URL.Host, transport.TLSClientConfig)
157-
if err != nil {
158-
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", err, t.URL)))
159-
return
160-
}
161-
defer out.Close()
162-
} else {
163-
out, err = net.Dial("tcp", t.URL.Host)
164-
if err != nil {
165-
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", err, t.URL)))
166-
return
167-
}
168-
defer out.Close()
156+
out, err := dialFunc(c.Request().Context(), "tcp", t.URL.Host)
157+
if err != nil {
158+
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", err, t.URL)))
159+
return
169160
}
170161

171162
// Write header

middleware/proxy_test.go

Lines changed: 141 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -813,14 +813,8 @@ func TestModifyResponseUseContext(t *testing.T) {
813813
assert.Equal(t, "CUSTOM_BALANCER", rec.Header().Get("FROM_BALANCER"))
814814
}
815815

816-
func TestProxyWithConfigWebSocketTCP(t *testing.T) {
817-
/*
818-
Arrange
819-
*/
820-
e := echo.New()
821-
822-
// Create a WebSocket test server
823-
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
816+
func createSimpleWebSocketServer(serveTLS bool) *httptest.Server {
817+
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
824818
wsHandler := func(conn *websocket.Conn) {
825819
defer conn.Close()
826820
for {
@@ -834,15 +828,59 @@ func TestProxyWithConfigWebSocketTCP(t *testing.T) {
834828
}
835829
}
836830
websocket.Server{Handler: wsHandler}.ServeHTTP(w, r)
837-
}))
838-
defer srv.Close()
831+
})
832+
if serveTLS {
833+
return httptest.NewTLSServer(handler)
834+
}
835+
return httptest.NewServer(handler)
836+
}
839837

840-
tgtURL, _ := url.Parse(srv.URL)
841-
balancer := NewRandomBalancer([]*ProxyTarget{{URL: tgtURL}})
838+
func createSimpleProxyServer(t *testing.T, srv *httptest.Server, serveTLS bool, toTLS bool) *httptest.Server {
839+
e := echo.New()
842840

843-
e.Use(ProxyWithConfig(ProxyConfig{Balancer: balancer}))
841+
if toTLS {
842+
// proxy to tls target
843+
tgtURL, _ := url.Parse(srv.URL)
844+
tgtURL.Scheme = "wss"
845+
balancer := NewRandomBalancer([]*ProxyTarget{{URL: tgtURL}})
844846

847+
defaultTransport, ok := http.DefaultTransport.(*http.Transport)
848+
if !ok {
849+
t.Fatal("Default transport is not of type *http.Transport")
850+
}
851+
transport := defaultTransport.Clone()
852+
transport.TLSClientConfig = &tls.Config{
853+
InsecureSkipVerify: true,
854+
}
855+
e.Use(ProxyWithConfig(ProxyConfig{Balancer: balancer, Transport: transport}))
856+
} else {
857+
// proxy to non-TLS target
858+
tgtURL, _ := url.Parse(srv.URL)
859+
balancer := NewRandomBalancer([]*ProxyTarget{{URL: tgtURL}})
860+
e.Use(ProxyWithConfig(ProxyConfig{Balancer: balancer}))
861+
}
862+
863+
if serveTLS {
864+
// serve proxy server with TLS
865+
ts := httptest.NewTLSServer(e)
866+
return ts
867+
}
868+
// serve proxy server without TLS
845869
ts := httptest.NewServer(e)
870+
return ts
871+
}
872+
873+
// TestProxyWithConfigWebSocketNonTLS2NonTLS tests the proxy with non-TLS to non-TLS WebSocket connection.
874+
func TestProxyWithConfigWebSocketNonTLS2NonTLS(t *testing.T) {
875+
/*
876+
Arrange
877+
*/
878+
// Create a WebSocket test server (non-TLS)
879+
srv := createSimpleWebSocketServer(false)
880+
defer srv.Close()
881+
882+
// create proxy server (non-TLS to non-TLS)
883+
ts := createSimpleProxyServer(t, srv, false, false)
846884
defer ts.Close()
847885

848886
tsURL, _ := url.Parse(ts.URL)
@@ -859,7 +897,7 @@ func TestProxyWithConfigWebSocketTCP(t *testing.T) {
859897
defer wsConn.Close()
860898

861899
// Send message
862-
sendMsg := "Hello, WebSocket!"
900+
sendMsg := "Hello, Non TLS WebSocket!"
863901
err = websocket.Message.Send(wsConn, sendMsg)
864902
assert.NoError(t, err)
865903

@@ -873,48 +911,103 @@ func TestProxyWithConfigWebSocketTCP(t *testing.T) {
873911
assert.Equal(t, sendMsg, recvMsg)
874912
}
875913

876-
func TestProxyWithConfigWebSocketTLS(t *testing.T) {
914+
// TestProxyWithConfigWebSocketTLS2TLS tests the proxy with TLS to TLS WebSocket connection.
915+
func TestProxyWithConfigWebSocketTLS2TLS(t *testing.T) {
877916
/*
878917
Arrange
879918
*/
880-
e := echo.New()
881-
882-
// Create a WebSocket test server
883-
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
884-
wsHandler := func(conn *websocket.Conn) {
885-
defer conn.Close()
886-
for {
887-
var msg string
888-
err := websocket.Message.Receive(conn, &msg)
889-
if err != nil {
890-
return
891-
}
892-
// message back to the client
893-
websocket.Message.Send(conn, msg)
894-
}
895-
}
896-
websocket.Server{Handler: wsHandler}.ServeHTTP(w, r)
897-
}))
919+
// Create a WebSocket test server (TLS)
920+
srv := createSimpleWebSocketServer(true)
898921
defer srv.Close()
899922

900-
// create proxy server
901-
tgtURL, _ := url.Parse(srv.URL)
902-
tgtURL.Scheme = "wss"
923+
// create proxy server (TLS to TLS)
924+
ts := createSimpleProxyServer(t, srv, true, true)
925+
defer ts.Close()
903926

904-
balancer := NewRandomBalancer([]*ProxyTarget{{URL: tgtURL}})
927+
tsURL, _ := url.Parse(ts.URL)
928+
tsURL.Scheme = "wss"
929+
tsURL.Path = "/"
905930

906-
defaultTransport, ok := http.DefaultTransport.(*http.Transport)
907-
if !ok {
908-
t.Fatal("Default transport is not of type *http.Transport")
909-
}
910-
transport := defaultTransport.Clone()
911-
transport.TLSClientConfig = &tls.Config{
912-
InsecureSkipVerify: true,
931+
/*
932+
Act
933+
*/
934+
origin, err := url.Parse(ts.URL)
935+
assert.NoError(t, err)
936+
config := &websocket.Config{
937+
Location: tsURL,
938+
Origin: origin,
939+
TlsConfig: &tls.Config{InsecureSkipVerify: true}, // skip verify for testing
940+
Version: websocket.ProtocolVersionHybi13,
913941
}
914-
e.Use(ProxyWithConfig(ProxyConfig{Balancer: balancer, Transport: transport}))
942+
wsConn, err := websocket.DialConfig(config)
943+
assert.NoError(t, err)
944+
defer wsConn.Close()
945+
946+
// Send message
947+
sendMsg := "Hello, TLS to TLS WebSocket!"
948+
err = websocket.Message.Send(wsConn, sendMsg)
949+
assert.NoError(t, err)
950+
951+
// Read response
952+
var recvMsg string
953+
err = websocket.Message.Receive(wsConn, &recvMsg)
954+
assert.NoError(t, err)
955+
assert.Equal(t, sendMsg, recvMsg)
956+
}
957+
958+
// TestProxyWithConfigWebSocketNonTLS2TLS tests the proxy with non-TLS to TLS WebSocket connection.
959+
func TestProxyWithConfigWebSocketNonTLS2TLS(t *testing.T) {
960+
/*
961+
Arrange
962+
*/
963+
964+
// Create a WebSocket test server (TLS)
965+
srv := createSimpleWebSocketServer(true)
966+
defer srv.Close()
967+
968+
// create proxy server (Non-TLS to TLS)
969+
ts := createSimpleProxyServer(t, srv, false, true)
970+
defer ts.Close()
971+
972+
tsURL, _ := url.Parse(ts.URL)
973+
tsURL.Scheme = "ws"
974+
tsURL.Path = "/"
975+
976+
/*
977+
Act
978+
*/
979+
// Connect to the proxy WebSocket
980+
wsConn, err := websocket.Dial(tsURL.String(), "", "http://localhost/")
981+
assert.NoError(t, err)
982+
defer wsConn.Close()
983+
984+
// Send message
985+
sendMsg := "Hello, Non TLS to TLS WebSocket!"
986+
err = websocket.Message.Send(wsConn, sendMsg)
987+
assert.NoError(t, err)
988+
989+
/*
990+
Assert
991+
*/
992+
// Read response
993+
var recvMsg string
994+
err = websocket.Message.Receive(wsConn, &recvMsg)
995+
assert.NoError(t, err)
996+
assert.Equal(t, sendMsg, recvMsg)
997+
}
998+
999+
// TestProxyWithConfigWebSocketTLSToNoneTLS tests the proxy with TLS to non-TLS WebSocket connection. (TLS termination)
1000+
func TestProxyWithConfigWebSocketTLS2NonTLS(t *testing.T) {
1001+
/*
1002+
Arrange
1003+
*/
1004+
1005+
// Create a WebSocket test server (non-TLS)
1006+
srv := createSimpleWebSocketServer(false)
1007+
defer srv.Close()
9151008

916-
// Start test server
917-
ts := httptest.NewTLSServer(e)
1009+
// create proxy server (TLS to non-TLS)
1010+
ts := createSimpleProxyServer(t, srv, true, false)
9181011
defer ts.Close()
9191012

9201013
tsURL, _ := url.Parse(ts.URL)
@@ -937,7 +1030,7 @@ func TestProxyWithConfigWebSocketTLS(t *testing.T) {
9371030
defer wsConn.Close()
9381031

9391032
// Send message
940-
sendMsg := "Hello, TLS WebSocket!"
1033+
sendMsg := "Hello, TLS to NoneTLS WebSocket!"
9411034
err = websocket.Message.Send(wsConn, sendMsg)
9421035
assert.NoError(t, err)
9431036

0 commit comments

Comments
 (0)