diff --git a/context.go b/context.go index 6bdd5c7..0ad79be 100644 --- a/context.go +++ b/context.go @@ -65,6 +65,20 @@ func (ctx *Context) onAuth(authType string, user string, pass string) bool { return ctx.Prx.OnAuth(ctx, authType, user, pass) } +func (ctx *Context) onDial(network, addr string) (c net.Conn, err error) { + defer func() { + if err, ok := recover().(error); ok { + ctx.doError("Dial", ErrPanic, err) + } + }() + if ctx.Prx.OnDial != nil { + return ctx.Prx.OnDial(network, addr) + } else { + return net.Dial(network, addr) + } + +} + func (ctx *Context) onConnect(host string) (ConnectAction ConnectAction, newHost string) { defer func() { @@ -208,7 +222,8 @@ func (ctx *Context) doConnect(w http.ResponseWriter, r *http.Request) (b bool) { ctx.ConnectHost = host switch ctx.ConnectAction { case ConnectProxy: - conn, err := net.Dial("tcp", host) + + conn, err := ctx.onDial("tcp", host) if err != nil { hijConn.Write([]byte("HTTP/1.1 404 Not Found\r\n\r\n")) hijConn.Close() diff --git a/proxy.go b/proxy.go index 0dcf09f..777300d 100644 --- a/proxy.go +++ b/proxy.go @@ -2,6 +2,7 @@ package httpproxy import ( "crypto/tls" + "net" "net/http" "sync/atomic" ) @@ -40,6 +41,8 @@ type Proxy struct { OnConnect func(ctx *Context, host string) (ConnectAction ConnectAction, newHost string) + OnDial func(network string, addr string) (c net.Conn, err error) + // Request callback. It greets remote request. // If it returns non-nil response, stops processing remote request. OnRequest func(ctx *Context, req *http.Request) (resp *http.Response)