Skip to content

Commit

Permalink
V1.1 more post support (#454)
Browse files Browse the repository at this point in the history
* use methods bitmap when possible for performance

* relocate query param funcs to params pkg

* broaden POST support and search/replace func reloc

* use Reader instead of Buffer for read-only bodies
  • Loading branch information
James Ranson authored Jun 3, 2020
1 parent cee6faf commit c11c41b
Show file tree
Hide file tree
Showing 26 changed files with 324 additions and 220 deletions.
12 changes: 6 additions & 6 deletions pkg/proxy/engines/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ import (
"github.com/tricksterproxy/trickster/pkg/proxy/errors"
"github.com/tricksterproxy/trickster/pkg/proxy/headers"
oo "github.com/tricksterproxy/trickster/pkg/proxy/origins/options"
"github.com/tricksterproxy/trickster/pkg/proxy/params"
po "github.com/tricksterproxy/trickster/pkg/proxy/paths/options"
"github.com/tricksterproxy/trickster/pkg/proxy/request"
tt "github.com/tricksterproxy/trickster/pkg/proxy/timeconv"
"github.com/tricksterproxy/trickster/pkg/sort/times"
"github.com/tricksterproxy/trickster/pkg/timeseries"
Expand Down Expand Up @@ -283,7 +283,7 @@ func parseDuration(input string) (time.Duration, error) {
func (c *TestClient) ParseTimeRangeQuery(r *http.Request) (*timeseries.TimeRangeQuery, error) {

trq := &timeseries.TimeRangeQuery{Extent: timeseries.Extent{}}
qp := r.URL.Query()
qp, _, _ := params.GetRequestValues(r)

trq.Statement = qp.Get(upQuery)
if trq.Statement == "" {
Expand Down Expand Up @@ -359,10 +359,10 @@ func (c *TestClient) BuildUpstreamURL(r *http.Request) *url.URL {

// SetExtent will change the upstream request query to use the provided Extent
func (c *TestClient) SetExtent(r *http.Request, trq *timeseries.TimeRangeQuery, extent *timeseries.Extent) {
v, _, _ := request.GetRequestValues(r)
v, _, _ := params.GetRequestValues(r)
v.Set(upStart, strconv.FormatInt(extent.Start.Unix(), 10))
v.Set(upEnd, strconv.FormatInt(extent.End.Unix(), 10))
request.SetRequestValues(r, v)
params.SetRequestValues(r, v)
}

// FastForwardRequest returns an *http.Request crafted to collect Fast Forward
Expand All @@ -382,7 +382,7 @@ func (c *TestClient) FastForwardRequest(r *http.Request) (*http.Request, error)
if strings.HasSuffix(nr.URL.Path, "/query_range") {
nr.URL.Path = nr.URL.Path[0 : len(nr.URL.Path)-6]
}
v, _, _ := request.GetRequestValues(nr)
v, _, _ := params.GetRequestValues(nr)
v.Del(upStart)
v.Del(upEnd)
v.Del(upStep)
Expand All @@ -392,7 +392,7 @@ func (c *TestClient) FastForwardRequest(r *http.Request) (*http.Request, error)
}
v.Set("time", strconv.FormatInt(c.fftime.Unix(), 10))

request.SetRequestValues(nr, v)
params.SetRequestValues(nr, v)
return nr, nil
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/proxy/engines/document.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ func (d *HTTPDocument) ParsePartialContentBody(resp *http.Response, body []byte,
}
}
} else if strings.HasPrefix(ct, headers.ValueMultipartByteRanges) {
p, ct, r, cl, err := byterange.ParseMultipartRangeResponseBody(ioutil.NopCloser(bytes.NewBuffer(body)), ct)
p, ct, r, cl, err := byterange.ParseMultipartRangeResponseBody(ioutil.NopCloser(bytes.NewReader(body)), ct)
if err == nil {
if d.RangeParts == nil {
d.Ranges = r
Expand Down
6 changes: 4 additions & 2 deletions pkg/proxy/engines/httpproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,9 @@ func PrepareFetchReader(r *http.Request) (io.ReadCloser, *http.Response, int64)

if pc != nil {
headers.UpdateHeaders(r.Header, pc.RequestHeaders)
params.UpdateParams(r.URL.Query(), pc.RequestParams)
qp, _, _ := params.GetRequestValues(r)
params.UpdateParams(qp, pc.RequestParams)
params.SetRequestValues(r, qp)
}

r.Close = false
Expand Down Expand Up @@ -238,7 +240,7 @@ func PrepareFetchReader(r *http.Request) (io.ReadCloser, *http.Response, int64)
if hasCustomResponseBody {
// Since we are not responding with the actual upstream response body, close it here
resp.Body.Close()
rc = ioutil.NopCloser(bytes.NewBuffer(pc.ResponseBodyBytes))
rc = ioutil.NopCloser(bytes.NewReader(pc.ResponseBodyBytes))
} else {
rc = resp.Body
}
Expand Down
17 changes: 9 additions & 8 deletions pkg/proxy/engines/key.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/tricksterproxy/trickster/pkg/proxy/errors"
"github.com/tricksterproxy/trickster/pkg/proxy/headers"
"github.com/tricksterproxy/trickster/pkg/proxy/methods"
"github.com/tricksterproxy/trickster/pkg/proxy/params"
"github.com/tricksterproxy/trickster/pkg/proxy/request"
"github.com/tricksterproxy/trickster/pkg/util/md5"
)
Expand All @@ -43,29 +44,29 @@ func (pr *proxyRequest) DeriveCacheKey(templateURL *url.URL, extra string) strin
return md5.Checksum(pr.URL.Path + extra)
}

var params url.Values
var qp url.Values
r := pr.Request

if pr.upstreamRequest != nil {
r = pr.upstreamRequest
if r.URL == nil {
r.URL = pr.URL
params = pr.URL.Query()
qp = pr.URL.Query()
}
}

var b []byte
if templateURL != nil {
params = templateURL.Query()
qp = templateURL.Query()
} else {
var s string
params, s, _ = request.GetRequestValues(r)
qp, s, _ = params.GetRequestValues(r)
b = []byte(s)
}

if pc.KeyHasher != nil && len(pc.KeyHasher) == 1 {
var k string
k, r.Body = pc.KeyHasher[0](r.URL.Path, params, r.Header, r.Body, extra)
k, r.Body = pc.KeyHasher[0](r.URL.Path, qp, r.Header, r.Body, extra)
return k
}

Expand All @@ -79,12 +80,12 @@ func (pr *proxyRequest) DeriveCacheKey(templateURL *url.URL, extra string) strin
vals = append(vals, fmt.Sprintf("%s.%s.", "method", r.Method))

if len(pc.CacheKeyParams) == 1 && pc.CacheKeyParams[0] == "*" {
for p := range params {
vals = append(vals, fmt.Sprintf("%s.%s.", p, params.Get(p)))
for p := range qp {
vals = append(vals, fmt.Sprintf("%s.%s.", p, qp.Get(p)))
}
} else {
for _, p := range pc.CacheKeyParams {
if v := params.Get(p); v != "" {
if v := qp.Get(p); v != "" {
vals = append(vals, fmt.Sprintf("%s.%s.", p, v))
}
}
Expand Down
10 changes: 5 additions & 5 deletions pkg/proxy/engines/objectproxycache.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,14 @@ func handleCachePartialHit(pr *proxyRequest) error {
err := d.FulfillContentBody()

if err == nil {
pr.upstreamResponse.Body = ioutil.NopCloser(bytes.NewBuffer(d.Body))
pr.upstreamResponse.Body = ioutil.NopCloser(bytes.NewReader(d.Body))
pr.upstreamResponse.Header.Set(headers.NameContentType, d.ContentType)
pr.upstreamReader = pr.upstreamResponse.Body
} else {
h, b := d.RangeParts.ExtractResponseRange(pr.wantedRanges, d.ContentLength, d.ContentType, nil)

headers.Merge(pr.upstreamResponse.Header, h)
pr.upstreamReader = ioutil.NopCloser(bytes.NewBuffer(b))
pr.upstreamReader = ioutil.NopCloser(bytes.NewReader(b))
}

} else if d != nil {
Expand Down Expand Up @@ -210,7 +210,7 @@ func handleCacheRevalidationResponse(pr *proxyRequest) error {
pr.upstreamResponse.StatusCode = pr.cacheDocument.StatusCode
pr.writeToCache = true
pr.store()
pr.upstreamReader = bytes.NewBuffer(pr.cacheDocument.Body)
pr.upstreamReader = bytes.NewReader(pr.cacheDocument.Body)
return handleTrueCacheHit(pr)
}

Expand All @@ -235,9 +235,9 @@ func handleTrueCacheHit(pr *proxyRequest) error {
if pr.wantsRanges {
h, b := d.RangeParts.ExtractResponseRange(pr.wantedRanges, d.ContentLength, d.ContentType, d.Body)
headers.Merge(pr.upstreamResponse.Header, h)
pr.upstreamReader = bytes.NewBuffer(b)
pr.upstreamReader = bytes.NewReader(b)
} else {
pr.upstreamReader = bytes.NewBuffer(d.Body)
pr.upstreamReader = bytes.NewReader(d.Body)
}

return handleResponse(pr)
Expand Down
11 changes: 6 additions & 5 deletions pkg/proxy/engines/proxy_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/tricksterproxy/trickster/pkg/locks"
tctx "github.com/tricksterproxy/trickster/pkg/proxy/context"
"github.com/tricksterproxy/trickster/pkg/proxy/headers"
"github.com/tricksterproxy/trickster/pkg/proxy/methods"
"github.com/tricksterproxy/trickster/pkg/proxy/ranges/byterange"
"github.com/tricksterproxy/trickster/pkg/proxy/request"
tspan "github.com/tricksterproxy/trickster/pkg/tracing/span"
Expand Down Expand Up @@ -504,7 +505,7 @@ func (pr *proxyRequest) prepareResponse() {
if pr.cachingPolicy.IsClientFresh {
// 304 on an If-None-Match only applies to GET/HEAD requests
// this bit will convert an INM-based 304 to a 412 on non-GET/HEAD
if (pr.Method != http.MethodGet && pr.Method != http.MethodHead) &&
if !methods.IsCacheable(pr.Method) &&
pr.cachingPolicy.HasIfNoneMatch && !pr.cachingPolicy.IfNoneMatchResult {
pr.upstreamResponse.StatusCode = http.StatusPreconditionFailed
} else {
Expand Down Expand Up @@ -548,7 +549,7 @@ func (pr *proxyRequest) prepareResponse() {
pr.trueContentType = d.ContentType
h, pr.responseBody = d.RangeParts.ExtractResponseRange(pr.wantedRanges, d.ContentLength, d.ContentType, d.Body)
headers.Merge(resp.Header, h)
pr.upstreamReader = bytes.NewBuffer(pr.responseBody)
pr.upstreamReader = bytes.NewReader(pr.responseBody)
} else if !pr.wantsRanges {
if resp.StatusCode == http.StatusPartialContent {
resp.StatusCode = http.StatusOK
Expand Down Expand Up @@ -703,19 +704,19 @@ func (pr *proxyRequest) reconstituteResponses() {
if bodyFromParts = len(parts.Ranges) > 1; !bodyFromParts {
err := parts.FulfillContentBody()
if bodyFromParts = err != nil; !bodyFromParts {
pr.upstreamReader = bytes.NewBuffer(parts.Body)
pr.upstreamReader = bytes.NewReader(parts.Body)
resp.StatusCode = http.StatusOK
pr.cacheBuffer = bytes.NewBuffer(parts.Body)
}
}
} else {
pr.upstreamReader = bytes.NewBuffer(parts.Body)
pr.upstreamReader = bytes.NewReader(parts.Body)
}

if bodyFromParts {
h, b := parts.RangeParts.Body(parts.ContentLength, parts.ContentType)
headers.Merge(resp.Header, h)
pr.upstreamReader = bytes.NewBuffer(b)
pr.upstreamReader = bytes.NewReader(b)
}
}

Expand Down
54 changes: 52 additions & 2 deletions pkg/proxy/methods/methods.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,45 @@ package methods
import "net/http"

const (
get uint16 = 1 << iota
head
post
put
patch
delete
options
connect
trace
purge
)

const (
cacheableMethods = get + head
bodyMethods = post + put + patch
uncacheableMethods = bodyMethods + delete + options + connect + trace + purge
allMethods = cacheableMethods + uncacheableMethods
)

const (
// Methods not currently in the base golang http package

// MethodPurge is the PURGE HTTP Method
MethodPurge = "PURGE"
)

var methodsMap = map[string]uint16{
http.MethodGet: get,
http.MethodHead: head,
http.MethodPost: post,
http.MethodPut: put,
http.MethodPatch: patch,
http.MethodDelete: delete,
http.MethodOptions: options,
http.MethodConnect: connect,
http.MethodTrace: trace,
MethodPurge: purge,
}

// AllHTTPMethods returns a list of all known HTTP methods
func AllHTTPMethods() []string {
return []string{http.MethodGet, http.MethodHead, http.MethodPost, http.MethodPut, http.MethodDelete,
Expand All @@ -46,10 +78,28 @@ func UncacheableHTTPMethods() []string {

// IsCacheable returns true if the method is HEAD or GET
func IsCacheable(method string) bool {
return method == http.MethodGet || method == http.MethodHead
if m, ok := methodsMap[method]; ok {
return (cacheableMethods&m != 0)
}
return false
}

// HasBody returns true if the method is POST, PUT or PATCH
func HasBody(method string) bool {
return method == http.MethodPost || method == http.MethodPatch || method == http.MethodPut
if m, ok := methodsMap[method]; ok {
return (bodyMethods&m != 0)
}
return false
}

// MethodMask returns the integer representation of the collection of methods
// based on the iota bitmask defined above
func MethodMask(methods ...string) uint16 {
var i uint16
for _, ms := range methods {
if m, ok := methodsMap[ms]; ok {
i ^= m
}
}
return i
}
12 changes: 12 additions & 0 deletions pkg/proxy/methods/methods_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ func TestIsCacheable(t *testing.T) {
if IsCacheable(http.MethodPut) {
t.Error("expected false")
}
if IsCacheable("invalid_method") {
t.Error("expected false")
}
}

func TestHasBody(t *testing.T) {
Expand All @@ -61,4 +64,13 @@ func TestHasBody(t *testing.T) {
if !HasBody(http.MethodPut) {
t.Error("expected true")
}
if HasBody("invalid_method") {
t.Error("expected false")
}
}

func TestMethodMask(t *testing.T) {
if v := MethodMask(http.MethodGet); v != 1 {
t.Errorf("expected 1 got %d", v)
}
}
27 changes: 22 additions & 5 deletions pkg/proxy/origins/influxdb/handler_health.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@
package influxdb

import (
"bytes"
"context"
"net/http"
"net/url"

tctx "github.com/tricksterproxy/trickster/pkg/proxy/context"
"github.com/tricksterproxy/trickster/pkg/proxy/engines"
"github.com/tricksterproxy/trickster/pkg/proxy/headers"
"github.com/tricksterproxy/trickster/pkg/proxy/methods"
"github.com/tricksterproxy/trickster/pkg/proxy/request"
"github.com/tricksterproxy/trickster/pkg/proxy/urls"
)
Expand All @@ -41,11 +43,16 @@ func (c *Client) HealthHandler(w http.ResponseWriter, r *http.Request) {
return
}

req, _ := http.NewRequest(c.healthMethod, c.healthURL.String(), nil)
req, _ := http.NewRequest(c.healthMethod, c.healthURL.String(), c.healthBody)
rsc := request.GetResources(r)
req = req.WithContext(tctx.WithHealthCheckFlag(tctx.WithResources(context.Background(), rsc), true))

req.Header = c.healthHeaders
if c.healthHeaders != nil {
c.healthHeaderLock.Lock()
req.Header = c.healthHeaders.Clone()
c.healthHeaderLock.Unlock()
}

engines.DoProxy(w, req, true)
}

Expand All @@ -64,13 +71,23 @@ func (c *Client) populateHeathCheckRequestValues() {
oc.HealthCheckQuery = q.Encode()
}

c.healthMethod = oc.HealthCheckVerb

c.healthURL = urls.Clone(c.baseUpstreamURL)
c.healthURL.Path += oc.HealthCheckUpstreamPath
c.healthURL.RawQuery = oc.HealthCheckQuery
c.healthMethod = oc.HealthCheckVerb

if oc.HealthCheckHeaders != nil {
if methods.HasBody(oc.HealthCheckVerb) && oc.HealthCheckQuery != "" {
c.healthHeaders = http.Header{}
c.healthHeaders.Set(headers.NameContentType, headers.ValueXFormURLEncoded)
c.healthBody = bytes.NewReader([]byte(oc.HealthCheckQuery))
} else {
c.healthURL.RawQuery = oc.HealthCheckQuery
}

if oc.HealthCheckHeaders != nil {
if c.healthHeaders == nil {
c.healthHeaders = http.Header{}
}
headers.UpdateHeaders(c.healthHeaders, oc.HealthCheckHeaders)
}
}
Loading

0 comments on commit c11c41b

Please sign in to comment.