From 73ad7fd65c4d65e8703c719d6c6fbeb4b8baff58 Mon Sep 17 00:00:00 2001 From: VinnyHC Date: Fri, 25 Feb 2022 08:54:39 -0800 Subject: [PATCH] dry --- api/sys_raft.go | 238 ++++++++++++++++++------------------------------ 1 file changed, 90 insertions(+), 148 deletions(-) diff --git a/api/sys_raft.go b/api/sys_raft.go index 685717401dd3..89c49d8082f8 100644 --- a/api/sys_raft.go +++ b/api/sys_raft.go @@ -138,81 +138,10 @@ func (c *Sys) RaftSnapshot(snapWriter io.Writer) error { r := c.c.NewRequest("GET", "/v1/sys/storage/raft/snapshot") r.URL.RawQuery = r.Params.Encode() - req, err := http.NewRequest(http.MethodGet, r.URL.RequestURI(), nil) - if err != nil { - return err - } - - req.URL.User = r.URL.User - req.URL.Scheme = r.URL.Scheme - req.URL.Host = r.URL.Host - req.Host = r.URL.Host - - if r.Headers != nil { - for header, vals := range r.Headers { - for _, val := range vals { - req.Header.Add(header, val) - } - } - } - - if len(r.ClientToken) != 0 { - req.Header.Set(consts.AuthHeaderName, r.ClientToken) - } - - if len(r.WrapTTL) != 0 { - req.Header.Set("X-Vault-Wrap-TTL", r.WrapTTL) - } - - if len(r.MFAHeaderVals) != 0 { - for _, mfaHeaderVal := range r.MFAHeaderVals { - req.Header.Add("X-Vault-MFA", mfaHeaderVal) - } - } - - if r.PolicyOverride { - req.Header.Set("X-Vault-Policy-Override", "true") - } - - // Avoiding the use of RawRequestWithContext which reads the response body - // to determine if the body contains error message. - var result *Response - resp, err := c.c.config.HttpClient.Do(req) - if err != nil { - return err - } - - if resp == nil { - return nil - } - - // Check for a redirect, only allowing for a single redirect - if resp.StatusCode == 301 || resp.StatusCode == 302 || resp.StatusCode == 307 { - // Parse the updated location - respLoc, err := resp.Location() - if err != nil { - return err - } - - // Ensure a protocol downgrade doesn't happen - if req.URL.Scheme == "https" && respLoc.Scheme != "https" { - return fmt.Errorf("redirect would cause protocol downgrade") - } - - // Update the request - req.URL = respLoc + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() - // Retry the request - resp, err = c.c.config.HttpClient.Do(req) - if err != nil { - return err - } - } - - result = &Response{Response: resp} - if err := result.Error(); err != nil { - return err - } + resp, err := c.requestWithContext(ctx, r, nil) // Make sure that the last file in the archive, SHA256SUMS.sealed, is present // and non-empty. This is to catch cases where the snapshot failed midstream, @@ -272,8 +201,7 @@ func (c *Sys) RaftSnapshot(snapWriter io.Writer) error { } // RaftSnapshotRestore reads the snapshot from the io.Reader and installs that -// snapshot, returning the cluster to the state defined by it. This avoids the use of -// RawRequestWithContext which copies the body (leading to possible OOMs) for retrying +// snapshot, returning the cluster to the state defined by it. func (c *Sys) RaftSnapshotRestore(snapReader io.Reader, force bool) error { path := "/v1/sys/storage/raft/snapshot" if force { @@ -281,84 +209,15 @@ func (c *Sys) RaftSnapshotRestore(snapReader io.Reader, force bool) error { } r := c.c.NewRequest(http.MethodPost, path) - r.URL.RawQuery = r.Params.Encode() - req, err := http.NewRequest(http.MethodPost, r.URL.RequestURI(), snapReader) - if err != nil { - return err - } - - req.URL.User = r.URL.User - req.URL.Scheme = r.URL.Scheme - req.URL.Host = r.URL.Host - req.Host = r.URL.Host - - if r.Headers != nil { - for header, vals := range r.Headers { - for _, val := range vals { - req.Header.Add(header, val) - } - } - } - - if len(r.ClientToken) != 0 { - req.Header.Set(consts.AuthHeaderName, r.ClientToken) - } - - if len(r.WrapTTL) != 0 { - req.Header.Set("X-Vault-Wrap-TTL", r.WrapTTL) - } - - if len(r.MFAHeaderVals) != 0 { - for _, mfaHeaderVal := range r.MFAHeaderVals { - req.Header.Add("X-Vault-MFA", mfaHeaderVal) - } - } - - if r.PolicyOverride { - req.Header.Set("X-Vault-Policy-Override", "true") - } - - var result *Response - resp, err := c.c.config.HttpClient.Do(req) - defer resp.Body.Close() + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + _, err := c.requestWithContext(ctx, r, snapReader) if err != nil { return err } - if resp == nil { - return nil - } - - // Check for a redirect, only allowing for a single redirect - if resp.StatusCode == 301 || resp.StatusCode == 302 || resp.StatusCode == 307 { - // Parse the updated location - respLoc, err := resp.Location() - if err != nil { - return err - } - - // Ensure a protocol downgrade doesn't happen - if req.URL.Scheme == "https" && respLoc.Scheme != "https" { - return fmt.Errorf("redirect would cause protocol downgrade") - } - - // Update the request - req.URL = respLoc - - // Retry the request - resp, err = c.c.config.HttpClient.Do(req) - if err != nil { - return err - } - } - - result = &Response{Response: resp} - if err := result.Error(); err != nil { - return err - } - return nil } @@ -456,3 +315,86 @@ func (c *Sys) PutRaftAutopilotConfiguration(opts *AutopilotConfig) error { return nil } + +// requestWithContext avoids the use of the go-retryable library and is useful when +// making requests where the req or resp body is likely larger than available memory +func (c *Sys) requestWithContext(ctx context.Context, r *Request, body io.Reader) (*Response, error) { + req, err := http.NewRequest(r.Method, r.URL.RequestURI(), body) + if err != nil { + return nil, err + } + + req.URL.User = r.URL.User + req.URL.Scheme = r.URL.Scheme + req.URL.Host = r.URL.Host + req.Host = r.URL.Host + + if r.Headers != nil { + for header, vals := range r.Headers { + for _, val := range vals { + req.Header.Add(header, val) + } + } + } + + if len(r.ClientToken) != 0 { + req.Header.Set(consts.AuthHeaderName, r.ClientToken) + } + + if len(r.WrapTTL) != 0 { + req.Header.Set("X-Vault-Wrap-TTL", r.WrapTTL) + } + + if len(r.MFAHeaderVals) != 0 { + for _, mfaHeaderVal := range r.MFAHeaderVals { + req.Header.Add("X-Vault-MFA", mfaHeaderVal) + } + } + + if r.PolicyOverride { + req.Header.Set("X-Vault-Policy-Override", "true") + } + + var result *Response + + resp, err := c.c.config.HttpClient.Do(req) + if err != nil { + return nil, err + } + + if resp == nil { + return nil, err + } + + defer resp.Body.Close() + + // Check for a redirect, only allowing for a single redirect + if resp.StatusCode == 301 || resp.StatusCode == 302 || resp.StatusCode == 307 { + // Parse the updated location + respLoc, err := resp.Location() + if err != nil { + return nil, err + } + + // Ensure a protocol downgrade doesn't happen + if req.URL.Scheme == "https" && respLoc.Scheme != "https" { + return nil, fmt.Errorf("redirect would cause protocol downgrade") + } + + // Update the request + req.URL = respLoc + + // Retry the request + resp, err = c.c.config.HttpClient.Do(req) + if err != nil { + return nil, err + } + } + + result = &Response{Response: resp} + if err := result.Error(); err != nil { + return nil, err + } + + return result, nil +}