diff --git a/README.md b/README.md index 6ae7f8e..b0bb1b9 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,7 @@ Flags: -i, --interval=250ms Interval for statistics calculation (reqlog mode) --preallocate=1000 Number of requests in req log to preallocate memory for per connection (reqlog mode) + --method=GET The HTTP request method (GET, POST, PUT, PATCH or DELETE) --version Show application version. Args: diff --git a/args.go b/args.go index ad3c0f9..bea237f 100644 --- a/args.go +++ b/args.go @@ -32,6 +32,7 @@ var ( preallocate = kingpin.Flag("preallocate", "Number of requests in req log to preallocate memory for per connection (reqlog mode)"). Default("1000"). Int() + method = kingpin.Flag("method", "The HTTP request method (GET, POST, PUT, PATCH or DELETE)").Default("GET").Enum("GET", "POST", "PUT", "PATCH", "DELETE") target = kingpin.Arg("target", "HTTP target URL").Required().String() ) diff --git a/gocannon.go b/gocannon.go index 5eeaef0..98c76ed 100644 --- a/gocannon.go +++ b/gocannon.go @@ -39,7 +39,7 @@ func runGocannon() error { for connectionID := 0; connectionID < n; connectionID++ { go func(c *fasthttp.HostClient, cid int) { for { - code, start, end := performRequest(c, *target) + code, start, end := performRequest(c, *target, *method) if end >= stop { break } diff --git a/http.go b/http.go index 02c6296..25202f3 100644 --- a/http.go +++ b/http.go @@ -19,7 +19,6 @@ func newHTTPClient( timeout time.Duration, connections int, ) (*fasthttp.HostClient, error) { - c := new(fasthttp.HostClient) u, err := url.ParseRequestURI(target) if err != nil { return nil, ErrWrongTarget @@ -27,7 +26,7 @@ func newHTTPClient( if u.Scheme != "http" { return nil, ErrUnsupportedProtocol } - c = &fasthttp.HostClient{ + c := &fasthttp.HostClient{ Addr: u.Host, MaxConns: int(connections), ReadTimeout: timeout, @@ -40,14 +39,14 @@ func newHTTPClient( return c, nil } -func performRequest(c *fasthttp.HostClient, target string) ( +func performRequest(c *fasthttp.HostClient, target string, method string) ( code int, start int64, end int64, ) { req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() req.URI().SetScheme("http") - req.Header.SetMethod("GET") + req.Header.SetMethod(method) req.SetRequestURI(target) start = makeTimestamp() diff --git a/http_test.go b/http_test.go index 0fc71c6..ce36e6b 100644 --- a/http_test.go +++ b/http_test.go @@ -35,11 +35,13 @@ func TestPerformRequest(t *testing.T) { c, _ := newHTTPClient("http://localhost:3000/", timeout, 10) - codeOk, _, _ := performRequest(c, "http://localhost:3000/") - codeISE, _, _ := performRequest(c, "http://localhost:3000/error") - codeTimeout, _, _ := performRequest(c, "http://localhost:3000/timeout") + codeOk, _, _ := performRequest(c, "http://localhost:3000/", "GET") + codePost, _, _ := performRequest(c, "http://localhost:3000/postonly", "POST") + codeISE, _, _ := performRequest(c, "http://localhost:3000/error", "GET") + codeTimeout, _, _ := performRequest(c, "http://localhost:3000/timeout", "GET") assert.Equal(t, 200, codeOk) + assert.Equal(t, 200, codePost) assert.Equal(t, http.StatusInternalServerError, codeISE) assert.Equal(t, 0, codeTimeout) } diff --git a/integration_test.go b/integration_test.go index 897a203..654a0e4 100644 --- a/integration_test.go +++ b/integration_test.go @@ -34,7 +34,7 @@ func TestGocannon(t *testing.T) { for connectionID := 0; connectionID < conns; connectionID++ { go func(c *fasthttp.HostClient, cid int) { for { - code, start, end := performRequest(c, target) + code, start, end := performRequest(c, target, "GET") if end >= stop { break } diff --git a/target_test.go b/target_test.go index 0747b63..60c50e4 100644 --- a/target_test.go +++ b/target_test.go @@ -15,6 +15,15 @@ func TestMain(m *testing.M) { fmt.Fprintf(w, "Hello") }) + http.HandleFunc("/postonly", func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + w.WriteHeader(http.StatusBadRequest) + fmt.Fprintf(w, "Wrong method") + } else { + fmt.Fprintf(w, "Ok") + } + }) + http.HandleFunc("/error", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusInternalServerError) fmt.Fprintf(w, "Oooops...")