From 21e801edb21e8de38e422a3dec9534545ed7fe1a Mon Sep 17 00:00:00 2001 From: Soule BA Date: Mon, 26 Feb 2024 15:42:32 +0100 Subject: [PATCH] Enable pulling large files in parallel This is an attempt to better meet the expectations of users who pull large files. If implemented, this will allow chunks of a given artifact layer to be pulled concurrently. Signed-off-by: Soule BA --- oci/client/client.go | 4 +- oci/client/pull.go | 223 +++++++++++++++++++++++++++++++++-- oci/client/pull_test.go | 21 ++++ oci/client/push_pull_test.go | 1 + oci/go.mod | 4 +- oci/go.sum | 6 + 6 files changed, 248 insertions(+), 11 deletions(-) diff --git a/oci/client/client.go b/oci/client/client.go index b3cd257a..c816855a 100644 --- a/oci/client/client.go +++ b/oci/client/client.go @@ -21,13 +21,15 @@ import ( "github.com/google/go-containerregistry/pkg/crane" "github.com/google/go-containerregistry/pkg/v1/remote" +"github.com/hashicorp/go-retryablehttp" "github.com/fluxcd/pkg/oci" ) // Client holds the options for accessing remote OCI registries. type Client struct { - options []crane.Option + options []crane.Option + httpClient *retryablehttp.Client } // NewClient returns an OCI client configured with the given crane options. 
diff --git a/oci/client/pull.go b/oci/client/pull.go index 633bf4b3..b49892a7 100644 --- a/oci/client/pull.go +++ b/oci/client/pull.go @@ -22,13 +22,28 @@ import ( "context" "fmt" "io" + "net/http" + "net/url" "os" + "github.com/fluxcd/pkg/tar" + "github.com/google/go-containerregistry/pkg/authn" "github.com/google/go-containerregistry/pkg/crane" "github.com/google/go-containerregistry/pkg/name" - gcrv1 "github.com/google/go-containerregistry/pkg/v1" + "github.com/hashicorp/go-retryablehttp" - "github.com/fluxcd/pkg/tar" + v1 "github.com/google/go-containerregistry/pkg/v1" + "github.com/google/go-containerregistry/pkg/v1/remote" + "github.com/google/go-containerregistry/pkg/v1/remote/transport" + "golang.org/x/sync/errgroup" +) + +const ( + // thresholdForConcurrentPull is the maximum size of a layer to be extracted in one go. + // If the layer is larger than this, it will be downloaded in chunks. + thresholdForConcurrentPull = 100 * 1024 * 1024 // 100MB + // maxConcurrentPulls is the maximum number of concurrent downloads. + maxConcurrentPulls = 10 ) var ( @@ -39,8 +54,12 @@ var ( // PullOptions contains options for pulling a layer. type PullOptions struct { - layerIndex int - layerType LayerType + layerIndex int + layerType LayerType + transport http.RoundTripper + auth authn.Authenticator + keychain authn.Keychain + concurrency int } // PullOption is a function for configuring PullOptions. @@ -60,22 +79,53 @@ func WithPullLayerIndex(i int) PullOption { } } +func WithTransport(t http.RoundTripper) PullOption { + return func(o *PullOptions) { + o.transport = t + } +} + +func WithConcurrency(c int) PullOption { + return func(o *PullOptions) { + o.concurrency = c + } +} + // Pull downloads an artifact from an OCI repository and extracts the content. // It untar or copies the content to the given outPath depending on the layerType. // If no layer type is given, it tries to determine the right type by checking compressed content of the layer. 
-func (c *Client) Pull(ctx context.Context, url, outPath string, opts ...PullOption) (*Metadata, error) { +func (c *Client) Pull(ctx context.Context, urlString, outPath string, opts ...PullOption) (*Metadata, error) { o := &PullOptions{ layerIndex: 0, } + o.keychain = authn.DefaultKeychain for _, opt := range opts { opt(o) } - ref, err := name.ParseReference(url) + + if o.concurrency == 0 || o.concurrency > maxConcurrentPulls { + o.concurrency = maxConcurrentPulls + } + + if o.transport == nil { + transport := remote.DefaultTransport.(*http.Transport).Clone() + o.transport = transport + } + + ref, err := name.ParseReference(urlString) if err != nil { return nil, fmt.Errorf("invalid URL: %w", err) } - img, err := crane.Pull(url, c.optionsWithContext(ctx)...) + if c.httpClient == nil { + h, err := makeHttpClient(ctx, ref.Context(), *o) + if err != nil { + return nil, err + } + c.httpClient = h + } + + img, err := crane.Pull(urlString, c.optionsWithContext(ctx)...) if err != nil { return nil, err } @@ -91,7 +141,7 @@ func (c *Client) Pull(ctx context.Context, url, outPath string, opts ...PullOpti } meta := MetadataFromAnnotations(manifest.Annotations) - meta.URL = url + meta.URL = urlString meta.Digest = ref.Context().Digest(digest.String()).String() layers, err := img.Layers() @@ -107,6 +157,34 @@ func (c *Client) Pull(ctx context.Context, url, outPath string, opts ...PullOpti return nil, fmt.Errorf("index '%d' out of bound for '%d' layers in artifact", o.layerIndex, len(layers)) } + size, err := layers[o.layerIndex].Size() + if err != nil { + return nil, fmt.Errorf("failed to get layer size: %w", err) + } + + if size > thresholdForConcurrentPull { + digest, err := layers[o.layerIndex].Digest() + if err != nil { + return nil, fmt.Errorf("parsing digest failed: %w", err) + } + u := url.URL{ + Scheme: ref.Context().Scheme(), + Host: ref.Context().RegistryStr(), + Path: fmt.Sprintf("/v2/%s/blobs/%s", ref.Context().RepositoryStr(), digest.String()), + } + ok, err := 
c.IsRangeRequestEnabled(ctx, u) + if err != nil { + return nil, fmt.Errorf("failed to check range request support: %w", err) + } + if ok { + err = c.concurrentExtractLayer(ctx, u, layers[o.layerIndex], outPath, digest, size, o.concurrency) + if err != nil { + return nil, err + } + return meta, nil + } + } + err = extractLayer(layers[o.layerIndex], outPath, o.layerType) if err != nil { return nil, err @@ -114,8 +192,98 @@ func (c *Client) Pull(ctx context.Context, url, outPath string, opts ...PullOpti return meta, nil } +// TO DO: handle authentication handle using keychain for authentication +func (c *Client) IsRangeRequestEnabled(ctx context.Context, u url.URL) (bool, error) { + req, err := retryablehttp.NewRequest(http.MethodHead, u.String(), nil) + if err != nil { + return false, err + } + + resp, err := c.httpClient.Do(req.WithContext(ctx)) + if err != nil { + return false, err + } + + if err := transport.CheckError(resp, http.StatusOK); err != nil { + return false, err + } + + if rangeUnit := resp.Header.Get("Accept-Ranges"); rangeUnit == "bytes" { + return true, nil + } + for k, v := range resp.Header { + fmt.Printf("Header: %s, Value: %s\n", k, v) + } + return false, nil +} + +func (c *Client) concurrentExtractLayer(ctx context.Context, u url.URL, layer v1.Layer, path string, digest v1.Hash, size int64, concurrency int) error { + chunkSize := size / int64(concurrency) + chunks := make([][]byte, concurrency+1) + diff := size % int64(concurrency) + + g, ctx := errgroup.WithContext(ctx) + for i := 0; i < concurrency; i++ { + i := i + g.Go(func() (err error) { + start, end := int64(i)*chunkSize, int64(i+1)*chunkSize + if i == concurrency-1 { + end += diff + } + req, err := retryablehttp.NewRequest(http.MethodGet, u.String(), nil) + if err != nil { + return fmt.Errorf("failed to create a new request: %w", err) + } + req.Header.Add("Range", fmt.Sprintf("bytes=%d-%d", start, end-1)) + resp, err := c.httpClient.Do(req.WithContext(ctx)) + if err != nil { + return 
fmt.Errorf("failed to download archive: %w", err) + } + defer resp.Body.Close() + + if err := transport.CheckError(resp, http.StatusPartialContent); err != nil { + return fmt.Errorf("failed to download archive from %s (status: %s)", u.String(), resp.Status) + } + + c, err := io.ReadAll(io.LimitReader(resp.Body, end-start)) + if err != nil { + return fmt.Errorf("failed to read response body: %w", err) + } + chunks[i] = c + return nil + }) + } + err := g.Wait() + if err != nil { + return err + } + + content := bufio.NewReader(bytes.NewReader(bytes.Join(chunks, nil))) + d, s, err := v1.SHA256(content) + if err != nil { + return err + } + if d != digest { + return fmt.Errorf("digest mismatch: expected %s, got %s", digest, d) + } + if s != size { + return fmt.Errorf("size mismatch: expected %d, got %d", size, size) + } + + f, err := os.Create(path) + if err != nil { + return err + } + + _, err = io.Copy(f, content) + if err != nil { + return fmt.Errorf("error copying layer content: %s", err) + } + return nil +} + // extractLayer extracts the Layer to the path -func extractLayer(layer gcrv1.Layer, path string, layerType LayerType) error { +func extractLayer(layer v1.Layer, path string, layerType LayerType) error { var blob io.Reader blob, err := layer.Compressed() if err != nil { @@ -173,3 +341,40 @@ func isGzipBlob(buf *bufio.Reader) (bool, error) { } return bytes.Equal(b, gzipMagicHeader), nil } + +type resource interface { + Scheme() string + RegistryStr() string + Scope(string) string + + authn.Resource +} + +func makeHttpClient(ctx context.Context, target resource, o PullOptions) (*retryablehttp.Client, error) { + auth := o.auth + if o.keychain != nil { + kauth, err := o.keychain.Resolve(target) + if err != nil { + return nil, err + } + auth = kauth + } + + reg, ok := target.(name.Registry) + if !ok { + repo, ok := target.(name.Repository) + if !ok { + return nil, fmt.Errorf("unexpected resource: %T", target) + } + reg = repo.Registry + } + + tr, err := 
transport.NewWithContext(ctx, reg, auth, o.transport, []string{target.Scope(transport.PullScope)}) + if err != nil { + return nil, err + } + + h := retryablehttp.NewClient() + h.HTTPClient = &http.Client{Transport: tr} + return h, nil +} diff --git a/oci/client/pull_test.go b/oci/client/pull_test.go index 86795284..b68dd15a 100644 --- a/oci/client/pull_test.go +++ b/oci/client/pull_test.go @@ -41,6 +41,7 @@ func Test_PullAnyTarball(t *testing.T) { repo := "test-no-annotations" + randStringRunes(5) dst := fmt.Sprintf("%s/%s:%s", dockerReg, repo, tag) + fmt.Println("Pulling from:", dst) artifact := filepath.Join(t.TempDir(), "artifact.tgz") g.Expect(build(artifact, testDir, nil)).To(Succeed()) @@ -82,3 +83,23 @@ func Test_PullAnyTarball(t *testing.T) { g.Expect(extractTo + "/" + entry).To(Or(BeAnExistingFile(), BeADirectory())) } } + +func Test_PullLargeTarball(t *testing.T) { + g := NewWithT(t) + ctx := context.Background() + c := NewClient(DefaultOptions()) + dst := "vnp505/zephyr-7b-alpha:alpha" + extractTo := filepath.Join(t.TempDir(), "artifact") + m, err := c.Pull(ctx, dst, extractTo, WithPullLayerIndex(19)) + fmt.Println("Pulled from:", dst) + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(m).ToNot(BeNil()) + g.Expect(m.Annotations).To(BeEmpty()) + g.Expect(m.Created).To(BeEmpty()) + g.Expect(m.Revision).To(BeEmpty()) + g.Expect(m.Source).To(BeEmpty()) + g.Expect(m.URL).To(Equal(dst)) + g.Expect(m.Digest).ToNot(BeEmpty()) + g.Expect(extractTo).ToNot(BeEmpty()) +} diff --git a/oci/client/push_pull_test.go b/oci/client/push_pull_test.go index 3c68b253..9d02f101 100644 --- a/oci/client/push_pull_test.go +++ b/oci/client/push_pull_test.go @@ -305,6 +305,7 @@ func Test_Push_Pull(t *testing.T) { g.Expect(err).ToNot(HaveOccurred()) fileInfo, err := os.Stat(tt.sourcePath) + g.Expect(err).ToNot(HaveOccurred()) // if a directory was pushed, then the created file should be a gzipped archive if fileInfo.IsDir() { bufReader := 
bufio.NewReader(bytes.NewReader(got)) diff --git a/oci/go.mod b/oci/go.mod index e992b575..50681ead 100644 --- a/oci/go.mod +++ b/oci/go.mod @@ -21,9 +21,11 @@ require ( github.com/fluxcd/pkg/tar v0.4.0 github.com/fluxcd/pkg/version v0.2.2 github.com/google/go-containerregistry v0.18.0 + github.com/hashicorp/go-retryablehttp v0.7.5 github.com/onsi/gomega v1.31.1 github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5 github.com/sirupsen/logrus v1.9.3 + golang.org/x/sync v0.6.0 sigs.k8s.io/controller-runtime v0.16.3 ) @@ -80,6 +82,7 @@ require ( github.com/gorilla/handlers v1.5.1 // indirect github.com/gorilla/mux v1.8.1 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0 // indirect + github.com/hashicorp/go-cleanhttp v0.5.2 // indirect github.com/hashicorp/golang-lru/arc/v2 v2.0.5 // indirect github.com/hashicorp/golang-lru/v2 v2.0.5 // indirect github.com/imdario/mergo v0.3.15 // indirect @@ -130,7 +133,6 @@ require ( golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e // indirect golang.org/x/net v0.20.0 // indirect golang.org/x/oauth2 v0.16.0 // indirect - golang.org/x/sync v0.6.0 // indirect golang.org/x/sys v0.16.0 // indirect golang.org/x/term v0.16.0 // indirect golang.org/x/text v0.14.0 // indirect diff --git a/oci/go.sum b/oci/go.sum index 87ecfa5d..45aeee99 100644 --- a/oci/go.sum +++ b/oci/go.sum @@ -155,6 +155,12 @@ github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0 h1:YBftPWNWd4WwGqtY2yeZL2ef8rHAxPBD8KFhJpmcqms= github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0/go.mod h1:YN5jB8ie0yfIUg6VvR9Kz84aCaG7AsGZnLjhHbUqwPg= +github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= +github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= +github.com/hashicorp/go-hclog v0.9.2 
h1:CG6TE5H9/JXsFWJCfoIVpKFIkFe6ysEuHirp4DxCsHI= +github.com/hashicorp/go-hclog v0.9.2/go.mod h1:5CU+agLiy3J7N7QjHK5d05KxGsuXiQLrjA0H7acj2lQ= +github.com/hashicorp/go-retryablehttp v0.7.5 h1:bJj+Pj19UZMIweq/iie+1u5YCdGrnxCT9yvm0e+Nd5M= +github.com/hashicorp/go-retryablehttp v0.7.5/go.mod h1:Jy/gPYAdjqffZ/yFGCFV2doI5wjtH1ewM9u8iYVjtX8= github.com/hashicorp/golang-lru/arc/v2 v2.0.5 h1:l2zaLDubNhW4XO3LnliVj0GXO3+/CGNJAg1dcN2Fpfw= github.com/hashicorp/golang-lru/arc/v2 v2.0.5/go.mod h1:ny6zBSQZi2JxIeYcv7kt2sH2PXJtirBN7RDhRpxPkxU= github.com/hashicorp/golang-lru/v2 v2.0.5 h1:wW7h1TG88eUIJ2i69gaE3uNVtEPIagzhGvHgwfx2Vm4=