diff --git a/adapter.go b/adapter.go new file mode 100644 index 0000000..a38d53d --- /dev/null +++ b/adapter.go @@ -0,0 +1,68 @@ +package debugserver + +import ( + lager "code.cloudfoundry.org/lager/v3" + "errors" + "net/http" + "strings" +) + +// zapLogLevelController is an interface that defines a method to set the minimum log level. +type zapLogLevelController interface { + SetMinLevel(level lager.LogLevel) +} + +// LagerAdapter is an adapter for the ReconfigurableSinkInterface to work with lager.LogLevel. +type LagerAdapter struct { + Sink ReconfigurableSinkInterface +} + +// SetMinLevel sets the minimum log level for the LagerAdapter. +func (l *LagerAdapter) SetMinLevel(level lager.LogLevel) { + l.Sink.SetMinLevel(level) +} + +// normalizeLogLevel returns a single value that represents +// various forms of the same input level. For example: +// "0", "d", "debug", all of these represents debug log level. +func normalizeLogLevel(input string) string { + switch strings.ToLower(strings.TrimSpace(input)) { + case "0", "d", "debug": + return "debug" + case "1", "i", "info": + return "info" + case "2", "w", "warn": + return "warn" + case "3", "e", "error": + return "error" + case "4", "f", "fatal": + return "fatal" + default: + return "" + } +} + +// validateAndNormalize does two things: +// It validates the incoming request is HTTP type, uses POST method and has non-nil level specified. +// It also normalizes the various forms of the same log level type. For ex: 0, d, debug are all same. +func validateAndNormalize(w http.ResponseWriter, r *http.Request, level []byte) (string, error) { + if r.Method != http.MethodPost { + return "", errors.New("method not allowed, use POST") + } + + if r.TLS != nil { + return "", errors.New("invalid scheme, https is not allowed") + } + + if len(level) == 0 { + return "", errors.New("log level cannot be empty") + } + + input := strings.TrimSpace(string(level)) + normalized := normalizeLogLevel(input) + if normalized == "" { + return "", errors.New("invalid log level: " + string(level)) + } + + return normalized, nil +} diff --git a/cf_debug_server_test.go b/cf_debug_server_test.go index 063d8d7..84f24f3 100644 --- a/cf_debug_server_test.go +++ b/cf_debug_server_test.go @@ -1,12 +1,14 @@ package debugserver_test import ( - "bytes" + "crypto/tls" "flag" "fmt" + "io" "net" "net/http" - "strconv" + "net/http/httptest" + "strings" cf_debug_server "code.cloudfoundry.org/debugserver" lager "code.cloudfoundry.org/lager/v3" @@ -113,45 +115,110 @@ var _ = Describe("CF Debug Server", func() { Expect(netErr.Op).To(Equal("listen")) }) }) + }) - Context("checking log-level endpoint", func() { - validForms := map[lager.LogLevel][]string{ - lager.DEBUG: []string{"debug", "DEBUG", "d", strconv.Itoa(int(lager.DEBUG))}, - lager.INFO: []string{"info", "INFO", "i", strconv.Itoa(int(lager.INFO))}, - lager.ERROR: []string{"error", "ERROR", "e", strconv.Itoa(int(lager.ERROR))}, - lager.FATAL: []string{"fatal", "FATAL", "f", strconv.Itoa(int(lager.FATAL))}, - } + Describe("checking log-level endpoint with various inputs", func() { + var ( + req *http.Request + writer *httptest.ResponseRecorder + ) - //This will add another 16 unit tests to the suit - for level, acceptedForms := range validForms { - for _, form := range acceptedForms { - testLevel := level - testForm := form + BeforeEach(func() { + writer = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://%s/log-level", address), nil) + }) - It("can reconfigure the given sink with "+form, func() { - var err error - process, err = cf_debug_server.Run(address, sink) - Expect(err).NotTo(HaveOccurred()) + Context("valid log levels", func() { + DescribeTable("returns normalized log level", + func(input string, expected string) { + req.Body = io.NopCloser(strings.NewReader(input)) + levelBytes, _ := io.ReadAll(req.Body) + + actual, err := cf_debug_server.ValidateAndNormalize(writer, req, levelBytes) + Expect(err).ToNot(HaveOccurred()) + Expect(actual).To(Equal(expected)) + }, + + // Debug + Entry("debug - 0", "0", "debug"), + Entry("debug - d", "d", "debug"), + Entry("debug - debug", "debug", "debug"), + Entry("debug - DEBUG", "DEBUG", "debug"), + Entry("debug - DeBuG", "DeBuG", "debug"), + + // Info + Entry("info - 1", "1", "info"), + Entry("info - i", "i", "info"), + Entry("info - info", "info", "info"), + Entry("info - INFO", "INFO", "info"), + Entry("info - InFo", "InFo", "info"), + + // Warn + Entry("warn - 2", "2", "warn"), + Entry("warn - w", "w", "warn"), + Entry("warn - warn", "warn", "warn"), + Entry("warn - WARN", "WARN", "warn"), + Entry("warn - wARn", "wARn", "warn"), + + // Error + Entry("error - 3", "3", "error"), + Entry("error - e", "e", "error"), + Entry("error - error", "error", "error"), + Entry("error - ERROR", "ERROR", "error"), + Entry("error - eRroR", "eRroR", "error"), + + // Fatal + Entry("fatal - 4", "4", "fatal"), + Entry("fatal - f", "f", "fatal"), + Entry("fatal - fatal", "fatal", "fatal"), + Entry("fatal - FATAL", "FATAL", "fatal"), + Entry("fatal - FaTaL", "FaTaL", "fatal"), + ) + }) - sink.Log(lager.LogFormat{LogLevel: testLevel, Message: "hello before level change"}) - Eventually(logBuf).ShouldNot(gbytes.Say("hello before level change")) + Context("invalid log levels", func() { + It("fails on unsupported level", func() { + level := []byte("invalid") + actual, err := cf_debug_server.ValidateAndNormalize(writer, req, level) + Expect(err).To(HaveOccurred()) + Expect(actual).To(BeEmpty()) + }) - request, err := http.NewRequest("PUT", fmt.Sprintf("http://%s/log-level", address), bytes.NewBufferString(testForm)) + It("fails on empty level", func() { + level := []byte("") + actual, err := cf_debug_server.ValidateAndNormalize(writer, req, level) + Expect(err).To(HaveOccurred()) + Expect(actual).To(BeEmpty()) + }) + }) - Expect(err).NotTo(HaveOccurred()) + Context("invalid request method", func() { + It("returns error for non-POST", func() { + req.Method = http.MethodGet + actual, err := cf_debug_server.ValidateAndNormalize(writer, req, []byte("info")) + Expect(err).To(MatchError(ContainSubstring("method not allowed"))) + Expect(actual).To(BeEmpty()) + }) + }) - response, err := http.DefaultClient.Do(request) - Expect(err).NotTo(HaveOccurred()) + Context("invalid TLS scheme", func() { + It("returns error if TLS is used", func() { + req.TLS = &tls.ConnectionState{} + actual, err := cf_debug_server.ValidateAndNormalize(writer, req, []byte("debug")) + Expect(err).To(MatchError(ContainSubstring("invalid scheme"))) + Expect(actual).To(BeEmpty()) + }) + }) - Expect(response.StatusCode).To(Equal(http.StatusOK)) - response.Body.Close() + It("returns error if the request is made over HTTPS", func() { + // Simulate HTTPS by assigning a non-nil TLS connection state + req.TLS = &tls.ConnectionState{} + actual, err := cf_debug_server.ValidateAndNormalize(writer, req, []byte("debug")) - sink.Log(lager.LogFormat{LogLevel: testLevel, Message: "Logs sent with log-level " + testForm}) - Eventually(logBuf).Should(gbytes.Say("Logs sent with log-level " + testForm)) - }) - } - } + Expect(err).To(MatchError(ContainSubstring("invalid scheme"))) + Expect(actual).To(BeEmpty()) }) }) -}) + +}) \ No newline at end of file diff --git a/cf_debug_server_testhelper.go b/cf_debug_server_testhelper.go new file mode 100644 index 0000000..034ccd9 --- /dev/null +++ b/cf_debug_server_testhelper.go @@ -0,0 +1,12 @@ +//go:build test + +package debugserver + +import ( + "net/http" +) + +// Exported only for tests +func ValidateAndNormalize(w http.ResponseWriter, r *http.Request, level []byte) (string, error) { + return validateAndNormalize(w, r, level) +} diff --git a/server.go b/server.go index e1fbd4e..d93fab7 100644 --- a/server.go +++ b/server.go @@ -38,16 +38,19 @@ func DebugAddress(flags *flag.FlagSet) string { if dbgFlag == nil { return "" } - return dbgFlag.Value.String() } -func Runner(address string, sink ReconfigurableSinkInterface) ifrit.Runner { - return http_server.New(address, Handler(sink)) +// Run starts the debug server with the provided address and log controller. +// Run() -> runProcess() -> Runner() -> http_server.New() -> Handler() +func Run(address string, zapCtrl zapLogLevelController) (ifrit.Process, error) { + return runProcess(address, &LagerAdapter{zapCtrl}) } -func Run(address string, sink ReconfigurableSinkInterface) (ifrit.Process, error) { - p := ifrit.Invoke(Runner(address, sink)) +// runProcess starts the debug server and returns the process. +// It invokes the Runner with the provided address and log controller. +func runProcess(address string, zapCtrl zapLogLevelController) (ifrit.Process, error) { + p := ifrit.Invoke(Runner(address, zapCtrl)) select { case <-p.Ready(): case err := <-p.Wait(): @@ -56,7 +59,12 @@ func Run(address string, sink ReconfigurableSinkInterface) (ifrit.Process, error return p, nil } -func Handler(sink ReconfigurableSinkInterface) http.Handler { +// Runner creates an ifrit.Runner for the debug server with the provided address and log controller. +func Runner(address string, zapCtrl zapLogLevelController) ifrit.Runner { + return http_server.New(address, Handler(zapCtrl)) +} + +func Handler(zapCtrl zapLogLevelController) http.Handler { mux := http.NewServeMux() mux.Handle("/debug/pprof/", http.HandlerFunc(pprof.Index)) mux.Handle("/debug/pprof/trace", http.HandlerFunc(pprof.Trace)) @@ -64,20 +72,40 @@ func Handler(sink ReconfigurableSinkInterface) http.Handler { mux.Handle("/debug/pprof/profile", http.HandlerFunc(pprof.Profile)) mux.Handle("/debug/pprof/symbol", http.HandlerFunc(pprof.Symbol)) mux.Handle("/log-level", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Read the log level from the request body. level, err := io.ReadAll(r.Body) if err != nil { + http.Error(w, "Failed to read body", http.StatusBadRequest) return } - - switch string(level) { - case "debug", "DEBUG", "d", strconv.Itoa(int(lager.DEBUG)): - sink.SetMinLevel(lager.DEBUG) - case "info", "INFO", "i", strconv.Itoa(int(lager.INFO)): - sink.SetMinLevel(lager.INFO) - case "error", "ERROR", "e", strconv.Itoa(int(lager.ERROR)): - sink.SetMinLevel(lager.ERROR) - case "fatal", "FATAL", "f", strconv.Itoa(int(lager.FATAL)): - sink.SetMinLevel(lager.FATAL) + // Validate the log level request. + var normalizedLevel string + if normalizedLevel, err = validateAndNormalize(w, r, level); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + // Convert the log level to lager.LogLevel. + if normalizedLevel == "warn" { + // Note that zapcore.WarnLevel is not directly supported by lager. + // And lager does not have a separate WARN level, it uses INFO for warnings. + // So to set the minimum level to "warn" we send an Invalid log level of 99, + // which hits the default case in the SetMinLevel method. + // This is a workaround to ensure that the log level is set correctly. + zapCtrl.SetMinLevel(lager.LogLevel(99)) + } else { + lagerLogLevel, err := lager.LogLevelFromString(normalizedLevel) + if err != nil { + http.Error(w, "Invalid log level: "+err.Error(), http.StatusBadRequest) + return + } + zapCtrl.SetMinLevel(lagerLogLevel) + } + // Respond with a success message. + w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "text/plain") + w.Write([]byte("/log-level was invoked with Level: " + normalizedLevel + "\n")) + if normalizedLevel == "fatal" { + w.Write([]byte("Note: Fatal logs are reported as error logs in the Gorouter logs.\n")) } })) mux.Handle("/block-profile-rate", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -122,4 +150,4 @@ func Handler(sink ReconfigurableSinkInterface) http.Handler { })) return mux -} +} \ No newline at end of file