Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

V1.1 more post support #454

Merged
merged 5 commits into from
Jun 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions pkg/proxy/engines/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ import (
"github.com/tricksterproxy/trickster/pkg/proxy/errors"
"github.com/tricksterproxy/trickster/pkg/proxy/headers"
oo "github.com/tricksterproxy/trickster/pkg/proxy/origins/options"
"github.com/tricksterproxy/trickster/pkg/proxy/params"
po "github.com/tricksterproxy/trickster/pkg/proxy/paths/options"
"github.com/tricksterproxy/trickster/pkg/proxy/request"
tt "github.com/tricksterproxy/trickster/pkg/proxy/timeconv"
"github.com/tricksterproxy/trickster/pkg/sort/times"
"github.com/tricksterproxy/trickster/pkg/timeseries"
Expand Down Expand Up @@ -283,7 +283,7 @@ func parseDuration(input string) (time.Duration, error) {
func (c *TestClient) ParseTimeRangeQuery(r *http.Request) (*timeseries.TimeRangeQuery, error) {

trq := &timeseries.TimeRangeQuery{Extent: timeseries.Extent{}}
qp := r.URL.Query()
qp, _, _ := params.GetRequestValues(r)

trq.Statement = qp.Get(upQuery)
if trq.Statement == "" {
Expand Down Expand Up @@ -359,10 +359,10 @@ func (c *TestClient) BuildUpstreamURL(r *http.Request) *url.URL {

// SetExtent will change the upstream request query to use the provided Extent
func (c *TestClient) SetExtent(r *http.Request, trq *timeseries.TimeRangeQuery, extent *timeseries.Extent) {
v, _, _ := request.GetRequestValues(r)
v, _, _ := params.GetRequestValues(r)
v.Set(upStart, strconv.FormatInt(extent.Start.Unix(), 10))
v.Set(upEnd, strconv.FormatInt(extent.End.Unix(), 10))
request.SetRequestValues(r, v)
params.SetRequestValues(r, v)
}

// FastForwardRequest returns an *http.Request crafted to collect Fast Forward
Expand All @@ -382,7 +382,7 @@ func (c *TestClient) FastForwardRequest(r *http.Request) (*http.Request, error)
if strings.HasSuffix(nr.URL.Path, "/query_range") {
nr.URL.Path = nr.URL.Path[0 : len(nr.URL.Path)-6]
}
v, _, _ := request.GetRequestValues(nr)
v, _, _ := params.GetRequestValues(nr)
v.Del(upStart)
v.Del(upEnd)
v.Del(upStep)
Expand All @@ -392,7 +392,7 @@ func (c *TestClient) FastForwardRequest(r *http.Request) (*http.Request, error)
}
v.Set("time", strconv.FormatInt(c.fftime.Unix(), 10))

request.SetRequestValues(nr, v)
params.SetRequestValues(nr, v)
return nr, nil
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/proxy/engines/document.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ func (d *HTTPDocument) ParsePartialContentBody(resp *http.Response, body []byte,
}
}
} else if strings.HasPrefix(ct, headers.ValueMultipartByteRanges) {
p, ct, r, cl, err := byterange.ParseMultipartRangeResponseBody(ioutil.NopCloser(bytes.NewBuffer(body)), ct)
p, ct, r, cl, err := byterange.ParseMultipartRangeResponseBody(ioutil.NopCloser(bytes.NewReader(body)), ct)
if err == nil {
if d.RangeParts == nil {
d.Ranges = r
Expand Down
6 changes: 4 additions & 2 deletions pkg/proxy/engines/httpproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,9 @@ func PrepareFetchReader(r *http.Request) (io.ReadCloser, *http.Response, int64)

if pc != nil {
headers.UpdateHeaders(r.Header, pc.RequestHeaders)
params.UpdateParams(r.URL.Query(), pc.RequestParams)
qp, _, _ := params.GetRequestValues(r)
params.UpdateParams(qp, pc.RequestParams)
params.SetRequestValues(r, qp)
}

r.Close = false
Expand Down Expand Up @@ -238,7 +240,7 @@ func PrepareFetchReader(r *http.Request) (io.ReadCloser, *http.Response, int64)
if hasCustomResponseBody {
// Since we are not responding with the actual upstream response body, close it here
resp.Body.Close()
rc = ioutil.NopCloser(bytes.NewBuffer(pc.ResponseBodyBytes))
rc = ioutil.NopCloser(bytes.NewReader(pc.ResponseBodyBytes))
} else {
rc = resp.Body
}
Expand Down
17 changes: 9 additions & 8 deletions pkg/proxy/engines/key.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/tricksterproxy/trickster/pkg/proxy/errors"
"github.com/tricksterproxy/trickster/pkg/proxy/headers"
"github.com/tricksterproxy/trickster/pkg/proxy/methods"
"github.com/tricksterproxy/trickster/pkg/proxy/params"
"github.com/tricksterproxy/trickster/pkg/proxy/request"
"github.com/tricksterproxy/trickster/pkg/util/md5"
)
Expand All @@ -43,29 +44,29 @@ func (pr *proxyRequest) DeriveCacheKey(templateURL *url.URL, extra string) strin
return md5.Checksum(pr.URL.Path + extra)
}

var params url.Values
var qp url.Values
r := pr.Request

if pr.upstreamRequest != nil {
r = pr.upstreamRequest
if r.URL == nil {
r.URL = pr.URL
params = pr.URL.Query()
qp = pr.URL.Query()
}
}

var b []byte
if templateURL != nil {
params = templateURL.Query()
qp = templateURL.Query()
} else {
var s string
params, s, _ = request.GetRequestValues(r)
qp, s, _ = params.GetRequestValues(r)
b = []byte(s)
}

if pc.KeyHasher != nil && len(pc.KeyHasher) == 1 {
var k string
k, r.Body = pc.KeyHasher[0](r.URL.Path, params, r.Header, r.Body, extra)
k, r.Body = pc.KeyHasher[0](r.URL.Path, qp, r.Header, r.Body, extra)
return k
}

Expand All @@ -79,12 +80,12 @@ func (pr *proxyRequest) DeriveCacheKey(templateURL *url.URL, extra string) strin
vals = append(vals, fmt.Sprintf("%s.%s.", "method", r.Method))

if len(pc.CacheKeyParams) == 1 && pc.CacheKeyParams[0] == "*" {
for p := range params {
vals = append(vals, fmt.Sprintf("%s.%s.", p, params.Get(p)))
for p := range qp {
vals = append(vals, fmt.Sprintf("%s.%s.", p, qp.Get(p)))
}
} else {
for _, p := range pc.CacheKeyParams {
if v := params.Get(p); v != "" {
if v := qp.Get(p); v != "" {
vals = append(vals, fmt.Sprintf("%s.%s.", p, v))
}
}
Expand Down
10 changes: 5 additions & 5 deletions pkg/proxy/engines/objectproxycache.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,14 @@ func handleCachePartialHit(pr *proxyRequest) error {
err := d.FulfillContentBody()

if err == nil {
pr.upstreamResponse.Body = ioutil.NopCloser(bytes.NewBuffer(d.Body))
pr.upstreamResponse.Body = ioutil.NopCloser(bytes.NewReader(d.Body))
pr.upstreamResponse.Header.Set(headers.NameContentType, d.ContentType)
pr.upstreamReader = pr.upstreamResponse.Body
} else {
h, b := d.RangeParts.ExtractResponseRange(pr.wantedRanges, d.ContentLength, d.ContentType, nil)

headers.Merge(pr.upstreamResponse.Header, h)
pr.upstreamReader = ioutil.NopCloser(bytes.NewBuffer(b))
pr.upstreamReader = ioutil.NopCloser(bytes.NewReader(b))
}

} else if d != nil {
Expand Down Expand Up @@ -210,7 +210,7 @@ func handleCacheRevalidationResponse(pr *proxyRequest) error {
pr.upstreamResponse.StatusCode = pr.cacheDocument.StatusCode
pr.writeToCache = true
pr.store()
pr.upstreamReader = bytes.NewBuffer(pr.cacheDocument.Body)
pr.upstreamReader = bytes.NewReader(pr.cacheDocument.Body)
return handleTrueCacheHit(pr)
}

Expand All @@ -235,9 +235,9 @@ func handleTrueCacheHit(pr *proxyRequest) error {
if pr.wantsRanges {
h, b := d.RangeParts.ExtractResponseRange(pr.wantedRanges, d.ContentLength, d.ContentType, d.Body)
headers.Merge(pr.upstreamResponse.Header, h)
pr.upstreamReader = bytes.NewBuffer(b)
pr.upstreamReader = bytes.NewReader(b)
} else {
pr.upstreamReader = bytes.NewBuffer(d.Body)
pr.upstreamReader = bytes.NewReader(d.Body)
}

return handleResponse(pr)
Expand Down
11 changes: 6 additions & 5 deletions pkg/proxy/engines/proxy_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/tricksterproxy/trickster/pkg/locks"
tctx "github.com/tricksterproxy/trickster/pkg/proxy/context"
"github.com/tricksterproxy/trickster/pkg/proxy/headers"
"github.com/tricksterproxy/trickster/pkg/proxy/methods"
"github.com/tricksterproxy/trickster/pkg/proxy/ranges/byterange"
"github.com/tricksterproxy/trickster/pkg/proxy/request"
tspan "github.com/tricksterproxy/trickster/pkg/tracing/span"
Expand Down Expand Up @@ -504,7 +505,7 @@ func (pr *proxyRequest) prepareResponse() {
if pr.cachingPolicy.IsClientFresh {
// 304 on an If-None-Match only applies to GET/HEAD requests
// this bit will convert an INM-based 304 to a 412 on non-GET/HEAD
if (pr.Method != http.MethodGet && pr.Method != http.MethodHead) &&
if !methods.IsCacheable(pr.Method) &&
pr.cachingPolicy.HasIfNoneMatch && !pr.cachingPolicy.IfNoneMatchResult {
pr.upstreamResponse.StatusCode = http.StatusPreconditionFailed
} else {
Expand Down Expand Up @@ -548,7 +549,7 @@ func (pr *proxyRequest) prepareResponse() {
pr.trueContentType = d.ContentType
h, pr.responseBody = d.RangeParts.ExtractResponseRange(pr.wantedRanges, d.ContentLength, d.ContentType, d.Body)
headers.Merge(resp.Header, h)
pr.upstreamReader = bytes.NewBuffer(pr.responseBody)
pr.upstreamReader = bytes.NewReader(pr.responseBody)
} else if !pr.wantsRanges {
if resp.StatusCode == http.StatusPartialContent {
resp.StatusCode = http.StatusOK
Expand Down Expand Up @@ -703,19 +704,19 @@ func (pr *proxyRequest) reconstituteResponses() {
if bodyFromParts = len(parts.Ranges) > 1; !bodyFromParts {
err := parts.FulfillContentBody()
if bodyFromParts = err != nil; !bodyFromParts {
pr.upstreamReader = bytes.NewBuffer(parts.Body)
pr.upstreamReader = bytes.NewReader(parts.Body)
resp.StatusCode = http.StatusOK
pr.cacheBuffer = bytes.NewBuffer(parts.Body)
}
}
} else {
pr.upstreamReader = bytes.NewBuffer(parts.Body)
pr.upstreamReader = bytes.NewReader(parts.Body)
}

if bodyFromParts {
h, b := parts.RangeParts.Body(parts.ContentLength, parts.ContentType)
headers.Merge(resp.Header, h)
pr.upstreamReader = bytes.NewBuffer(b)
pr.upstreamReader = bytes.NewReader(b)
}
}

Expand Down
54 changes: 52 additions & 2 deletions pkg/proxy/methods/methods.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,45 @@ package methods
import "net/http"

const (
get uint16 = 1 << iota
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is cool!

head
post
put
patch
delete
options
connect
trace
purge
)

const (
cacheableMethods = get + head
bodyMethods = post + put + patch
uncacheableMethods = bodyMethods + delete + options + connect + trace + purge
allMethods = cacheableMethods + uncacheableMethods
)

const (
// Methods not currently in the base golang http package

// MethodPurge is the PURGE HTTP Method
MethodPurge = "PURGE"
)

var methodsMap = map[string]uint16{
http.MethodGet: get,
http.MethodHead: head,
http.MethodPost: post,
http.MethodPut: put,
http.MethodPatch: patch,
http.MethodDelete: delete,
http.MethodOptions: options,
http.MethodConnect: connect,
http.MethodTrace: trace,
MethodPurge: purge,
}

// AllHTTPMethods returns a list of all known HTTP methods
func AllHTTPMethods() []string {
return []string{http.MethodGet, http.MethodHead, http.MethodPost, http.MethodPut, http.MethodDelete,
Expand All @@ -46,10 +78,28 @@ func UncacheableHTTPMethods() []string {

// IsCacheable returns true if the method is HEAD or GET
func IsCacheable(method string) bool {
return method == http.MethodGet || method == http.MethodHead
if m, ok := methodsMap[method]; ok {
return (cacheableMethods&m != 0)
}
return false
}

// HasBody returns true if the method is POST, PUT or PATCH
func HasBody(method string) bool {
return method == http.MethodPost || method == http.MethodPatch || method == http.MethodPut
if m, ok := methodsMap[method]; ok {
return (bodyMethods&m != 0)
}
return false
}

// MethodMask returns the integer representation of the collection of methods
// based on the iota bitmask defined above
func MethodMask(methods ...string) uint16 {
var i uint16
for _, ms := range methods {
if m, ok := methodsMap[ms]; ok {
i ^= m
}
}
return i
}
12 changes: 12 additions & 0 deletions pkg/proxy/methods/methods_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ func TestIsCacheable(t *testing.T) {
if IsCacheable(http.MethodPut) {
t.Error("expected false")
}
if IsCacheable("invalid_method") {
t.Error("expected false")
}
}

func TestHasBody(t *testing.T) {
Expand All @@ -61,4 +64,13 @@ func TestHasBody(t *testing.T) {
if !HasBody(http.MethodPut) {
t.Error("expected true")
}
if HasBody("invalid_method") {
t.Error("expected false")
}
}

func TestMethodMask(t *testing.T) {
if v := MethodMask(http.MethodGet); v != 1 {
t.Errorf("expected 1 got %d", v)
}
}
27 changes: 22 additions & 5 deletions pkg/proxy/origins/influxdb/handler_health.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@
package influxdb

import (
"bytes"
"context"
"net/http"
"net/url"

tctx "github.com/tricksterproxy/trickster/pkg/proxy/context"
"github.com/tricksterproxy/trickster/pkg/proxy/engines"
"github.com/tricksterproxy/trickster/pkg/proxy/headers"
"github.com/tricksterproxy/trickster/pkg/proxy/methods"
"github.com/tricksterproxy/trickster/pkg/proxy/request"
"github.com/tricksterproxy/trickster/pkg/proxy/urls"
)
Expand All @@ -41,11 +43,16 @@ func (c *Client) HealthHandler(w http.ResponseWriter, r *http.Request) {
return
}

req, _ := http.NewRequest(c.healthMethod, c.healthURL.String(), nil)
req, _ := http.NewRequest(c.healthMethod, c.healthURL.String(), c.healthBody)
rsc := request.GetResources(r)
req = req.WithContext(tctx.WithHealthCheckFlag(tctx.WithResources(context.Background(), rsc), true))

req.Header = c.healthHeaders
if c.healthHeaders != nil {
c.healthHeaderLock.Lock()
req.Header = c.healthHeaders.Clone()
c.healthHeaderLock.Unlock()
}

engines.DoProxy(w, req, true)
}

Expand All @@ -64,13 +71,23 @@ func (c *Client) populateHeathCheckRequestValues() {
oc.HealthCheckQuery = q.Encode()
}

c.healthMethod = oc.HealthCheckVerb

c.healthURL = urls.Clone(c.baseUpstreamURL)
c.healthURL.Path += oc.HealthCheckUpstreamPath
c.healthURL.RawQuery = oc.HealthCheckQuery
c.healthMethod = oc.HealthCheckVerb

if oc.HealthCheckHeaders != nil {
if methods.HasBody(oc.HealthCheckVerb) && oc.HealthCheckQuery != "" {
c.healthHeaders = http.Header{}
c.healthHeaders.Set(headers.NameContentType, headers.ValueXFormURLEncoded)
c.healthBody = bytes.NewReader([]byte(oc.HealthCheckQuery))
} else {
c.healthURL.RawQuery = oc.HealthCheckQuery
}

if oc.HealthCheckHeaders != nil {
if c.healthHeaders == nil {
c.healthHeaders = http.Header{}
}
headers.UpdateHeaders(c.healthHeaders, oc.HealthCheckHeaders)
}
}
Loading