diff --git a/main.go b/main.go index f3dd57c..26d3d48 100644 --- a/main.go +++ b/main.go @@ -7,7 +7,6 @@ import ( "fmt" "math/rand" "net/http" - "os" "github.com/gorilla/mux" "github.com/rs/cors" @@ -26,7 +25,7 @@ func main() { err := json.Unmarshal(greetingsJson, &greetings) if err != nil { fmt.Printf("error loading greetings: %s\n", err) - os.Exit(1) + panic(err) } router := mux.NewRouter() @@ -57,6 +56,20 @@ func main() { } }).Methods("GET") + router.HandleFunc("/all", func(w http.ResponseWriter, r *http.Request) { + fmt.Printf("got /all request from %s\n", r.RemoteAddr) + w.Header().Set("Content-Type", "application/json") + jsonGreetings, err := json.Marshal(greetings) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + _, err = w.Write(jsonGreetings) + if err != nil { + panic(err) + } + }).Methods("GET") + c := cors.New(cors.Options{ AllowedOrigins: []string{ "http://greetings.kylepenfound.com", @@ -70,7 +83,7 @@ func main() { fmt.Printf("server closed\n") } else if err != nil { fmt.Printf("error starting server: %s\n", err) - os.Exit(1) + panic(err) } } diff --git a/main_test.go b/main_test.go index e05e339..2a3efa5 100644 --- a/main_test.go +++ b/main_test.go @@ -3,9 +3,11 @@ package main import ( "encoding/json" "fmt" - "os" + "net/http" + "net/http/httptest" "testing" + "github.com/gorilla/mux" "gotest.tools/v3/assert" ) @@ -14,7 +16,7 @@ func TestSelectGreeting(t *testing.T) { err := json.Unmarshal(greetingsJson, &greetings) if err != nil { fmt.Printf("error loading greetings: %s\n", err) - os.Exit(1) + panic(err) } english := &Greeting{ @@ -49,3 +51,46 @@ func TestFormatResponse(t *testing.T) { formatted := FormatResponse(g) assert.Equal(t, "{\"greeting\":\"Hello, World!\"}", formatted) } + +func TestAllEndpoint(t *testing.T) { + var greetings []*Greeting + err := json.Unmarshal(greetingsJson, &greetings) + if err != nil { + fmt.Printf("error loading greetings: %s\n", err) + panic(err) + } + + router := mux.NewRouter() + router.HandleFunc("/all", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + jsonGreetings, err := json.Marshal(greetings) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + _, err = w.Write(jsonGreetings) + if err != nil { + panic(err) + } + }).Methods("GET") + + rec := httptest.NewRecorder() + req, err := http.NewRequest("GET", "/all", nil) + if err != nil { + t.Fatal(err) + } + + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected status OK; got %v", rec.Code) + } + + var actualGreetings []*Greeting + err = json.Unmarshal(rec.Body.Bytes(), &actualGreetings) + if err != nil { + t.Fatal(err) + } + + assert.DeepEqual(t, greetings, actualGreetings) +}