diff --git a/integration/build_local_containerd_helper_test.go b/integration/build_local_containerd_helper_test.go new file mode 100644 index 0000000000000..43c296e549eb2 --- /dev/null +++ b/integration/build_local_containerd_helper_test.go @@ -0,0 +1,209 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package integration + +import ( + "context" + "fmt" + "path/filepath" + "sync" + "testing" + + "github.com/containerd/containerd" + "github.com/containerd/containerd/content" + "github.com/containerd/containerd/leases" + "github.com/containerd/containerd/pkg/cri/constants" + "github.com/containerd/containerd/platforms" + "github.com/containerd/containerd/plugin" + "github.com/containerd/containerd/services" + ctrdsrv "github.com/containerd/containerd/services/server" + srvconfig "github.com/containerd/containerd/services/server/config" + "github.com/containerd/containerd/snapshots" + + // NOTE: Importing containerd plugin(s) to build functionality in + // client side, which means there is no need to up server. It can + // prevent interference from testing with the same image. + containersapi "github.com/containerd/containerd/api/services/containers/v1" + diffapi "github.com/containerd/containerd/api/services/diff/v1" + imagesapi "github.com/containerd/containerd/api/services/images/v1" + introspectionapi "github.com/containerd/containerd/api/services/introspection/v1" + namespacesapi "github.com/containerd/containerd/api/services/namespaces/v1" + tasksapi "github.com/containerd/containerd/api/services/tasks/v1" + _ "github.com/containerd/containerd/diff/walking/plugin" + "github.com/containerd/containerd/events/exchange" + _ "github.com/containerd/containerd/events/plugin" + _ "github.com/containerd/containerd/gc/scheduler" + _ "github.com/containerd/containerd/leases/plugin" + _ "github.com/containerd/containerd/runtime/v2" + _ "github.com/containerd/containerd/runtime/v2/runc/options" + _ "github.com/containerd/containerd/services/containers" + _ "github.com/containerd/containerd/services/content" + _ "github.com/containerd/containerd/services/diff" + _ "github.com/containerd/containerd/services/events" + _ "github.com/containerd/containerd/services/images" + _ "github.com/containerd/containerd/services/introspection" + _ "github.com/containerd/containerd/services/leases" + _ "github.com/containerd/containerd/services/namespaces" + _ "github.com/containerd/containerd/services/snapshots" + _ "github.com/containerd/containerd/services/tasks" + _ "github.com/containerd/containerd/services/version" + + "github.com/stretchr/testify/assert" +) + +var ( + loadPluginOnce sync.Once + loadedPlugins []*plugin.Registration + loadedPluginsErr error +) + +// buildLocalContainerdClient is to return containerd client with initialized +// core plugins in local. 
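// An illustrative usage sketch (hypothetical, not part of the patch itself):
// a test in this package is expected to call the helper roughly as
//
//	cli := buildLocalContainerdClient(t, t.TempDir())
//	defer cli.Close()
//
// and then drive image operations against the in-process plugin instances,
// with no containerd daemon running.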
+func buildLocalContainerdClient(t *testing.T, tmpDir string) *containerd.Client { + ctx := context.Background() + + // load plugins + loadPluginOnce.Do(func() { + loadedPlugins, loadedPluginsErr = ctrdsrv.LoadPlugins(ctx, &srvconfig.Config{}) + assert.NoError(t, loadedPluginsErr) + }) + + // init plugins + var ( + // TODO: Remove this in 2.0 and let event plugin crease it + events = exchange.NewExchange() + + initialized = plugin.NewPluginSet() + + // NOTE: plugin.Set doesn't provide the way to get all the same + // type plugins. lastInitContext is used to record the last + // initContext and work with getServicesOpts. + lastInitContext *plugin.InitContext + + config = &srvconfig.Config{ + Version: 2, + Root: filepath.Join(tmpDir, "root"), + State: filepath.Join(tmpDir, "state"), + } + ) + + for _, p := range loadedPlugins { + initContext := plugin.NewContext( + ctx, + p, + initialized, + config.Root, + config.State, + ) + initContext.Events = events + + // load the plugin specific configuration if it is provided + if p.Config != nil { + pc, err := config.Decode(p) + assert.NoError(t, err) + + initContext.Config = pc + } + + result := p.Init(initContext) + assert.NoError(t, initialized.Add(result)) + + _, err := result.Instance() + assert.NoError(t, err) + + lastInitContext = initContext + } + + servicesOpts, err := getServicesOpts(lastInitContext) + assert.NoError(t, err) + + client, err := containerd.New( + "", + containerd.WithDefaultNamespace(constants.K8sContainerdNamespace), + containerd.WithDefaultPlatform(platforms.Default()), + containerd.WithServices(servicesOpts...), + ) + assert.NoError(t, err) + + return client +} + +// getServicesOpts get service options from plugin context. +// +// TODO(fuweid): It is copied from pkg/cri/cri.go. Should we make it as helper? 
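// A minimal sketch of the conversion it performs (assuming the plugin set
// initialized above): each required plugin instance becomes a
// containerd.ServicesOpt, for example
//
//	i, _ := ic.Get(plugin.LeasePlugin)
//	opt := containerd.WithLeasesService(i.(leases.Manager))
//
// so that the client above is wired to local service instances instead of a
// GRPC connection.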
+func getServicesOpts(ic *plugin.InitContext) ([]containerd.ServicesOpt, error) { + var opts []containerd.ServicesOpt + for t, fn := range map[plugin.Type]func(interface{}) containerd.ServicesOpt{ + plugin.EventPlugin: func(i interface{}) containerd.ServicesOpt { + return containerd.WithEventService(i.(containerd.EventService)) + }, + plugin.LeasePlugin: func(i interface{}) containerd.ServicesOpt { + return containerd.WithLeasesService(i.(leases.Manager)) + }, + } { + i, err := ic.Get(t) + if err != nil { + return nil, fmt.Errorf("failed to get %q plugin: %w", t, err) + } + opts = append(opts, fn(i)) + } + plugins, err := ic.GetByType(plugin.ServicePlugin) + if err != nil { + return nil, fmt.Errorf("failed to get service plugin: %w", err) + } + + for s, fn := range map[string]func(interface{}) containerd.ServicesOpt{ + services.ContentService: func(s interface{}) containerd.ServicesOpt { + return containerd.WithContentStore(s.(content.Store)) + }, + services.ImagesService: func(s interface{}) containerd.ServicesOpt { + return containerd.WithImageClient(s.(imagesapi.ImagesClient)) + }, + services.SnapshotsService: func(s interface{}) containerd.ServicesOpt { + return containerd.WithSnapshotters(s.(map[string]snapshots.Snapshotter)) + }, + services.ContainersService: func(s interface{}) containerd.ServicesOpt { + return containerd.WithContainerClient(s.(containersapi.ContainersClient)) + }, + services.TasksService: func(s interface{}) containerd.ServicesOpt { + return containerd.WithTaskClient(s.(tasksapi.TasksClient)) + }, + services.DiffService: func(s interface{}) containerd.ServicesOpt { + return containerd.WithDiffClient(s.(diffapi.DiffClient)) + }, + services.NamespacesService: func(s interface{}) containerd.ServicesOpt { + return containerd.WithNamespaceClient(s.(namespacesapi.NamespacesClient)) + }, + services.IntrospectionService: func(s interface{}) containerd.ServicesOpt { + return containerd.WithIntrospectionClient(s.(introspectionapi.IntrospectionClient)) + }, + } { + p := plugins[s] + if p == nil { + return nil, fmt.Errorf("service %q not found", s) + } + i, err := p.Instance() + if err != nil { + return nil, fmt.Errorf("failed to get instance of service %q: %w", s, err) + } + if i == nil { + return nil, fmt.Errorf("instance of service %q not found", s) + } + opts = append(opts, fn(i)) + } + return opts, nil +} diff --git a/integration/build_local_containerd_helper_test_linux.go b/integration/build_local_containerd_helper_test_linux.go new file mode 100644 index 0000000000000..8bd8f057c4ed5 --- /dev/null +++ b/integration/build_local_containerd_helper_test_linux.go @@ -0,0 +1,23 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ + +package integration + +import ( + // Register for linux platforms + _ "github.com/containerd/containerd/runtime/v1/linux" + _ "github.com/containerd/containerd/snapshots/overlay/plugin" +) diff --git a/integration/image_pull_timeout_test.go b/integration/image_pull_timeout_test.go new file mode 100644 index 0000000000000..fe3f53c422177 --- /dev/null +++ b/integration/image_pull_timeout_test.go @@ -0,0 +1,445 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package integration + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "os" + "path/filepath" + "runtime" + "strings" + "sync" + "testing" + "time" + + "github.com/containerd/containerd" + "github.com/containerd/containerd/content" + "github.com/containerd/containerd/leases" + "github.com/containerd/containerd/log" + "github.com/containerd/containerd/namespaces" + criconfig "github.com/containerd/containerd/pkg/cri/config" + criserver "github.com/containerd/containerd/pkg/cri/server" + ocispec "github.com/opencontainers/image-spec/specs-go/v1" + "github.com/stretchr/testify/assert" + runtimeapi "k8s.io/cri-api/pkg/apis/runtime/v1" +) + +var ( + defaultImagePullProgressTimeout = 5 * time.Second + pullProgressTestImageName = "ghcr.io/containerd/registry:2.7" +) + +func TestCRIImagePullTimeout(t *testing.T) { + t.Parallel() + + // TODO(fuweid): Test it in Windows. + if runtime.GOOS != "linux" { + t.Skip() + } + + t.Run("HoldingContentOpenWriter", testCRIImagePullTimeoutByHoldingContentOpenWriter) + t.Run("NoDataTransferred", testCRIImagePullTimeoutByNoDataTransferred) +} + +// testCRIImagePullTimeoutByHoldingContentOpenWriter tests that +// +// It should not cancel if there is no active http requests. +// +// When there are several pulling requests for the same blob content, there +// will only one active http request. It is singleflight. For the waiting pulling +// request, we should not cancel. +func testCRIImagePullTimeoutByHoldingContentOpenWriter(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + + cli := buildLocalContainerdClient(t, tmpDir) + + criService, err := initLocalCRIPlugin(cli, tmpDir, criconfig.Registry{}) + assert.NoError(t, err) + + ctx := namespaces.WithNamespace(context.Background(), k8sNamespace) + contentStore := cli.ContentStore() + + // imageIndexJSON is the manifest of ghcr.io/containerd/registry:2.7. 
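	// Explanatory note (based on the containerd content helpers): while a ref
	// is locked by one of the writers held below, content.OpenWriter in the
	// concurrent pull keeps retrying rather than failing, so the pull blocks
	// with no active HTTP request and the progress reporter must not cancel it.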
+ var imageIndexJSON = ` +{ + "manifests": [ + { + "digest": "sha256:b0b8dd398630cbb819d9a9c2fbd50561370856874b5d5d935be2e0af07c0ff4c", + "mediaType": "application/vnd.docker.distribution.manifest.v2+json", + "platform": { + "architecture": "amd64", + "os": "linux" + }, + "size": 1363 + }, + { + "digest": "sha256:6de6b4d5063876c92220d0438ae6068c778d9a2d3845b3d5c57a04a307998df6", + "mediaType": "application/vnd.docker.distribution.manifest.v2+json", + "platform": { + "architecture": "arm", + "os": "linux", + "variant": "v6" + }, + "size": 1363 + }, + { + "digest": "sha256:c11a277a91045f91866550314a988f937366bc2743859aa0f6ec8ef57b0458ce", + "mediaType": "application/vnd.docker.distribution.manifest.v2+json", + "platform": { + "architecture": "arm64", + "os": "linux", + "variant": "v8" + }, + "size": 1363 + } + ], + "mediaType": "application/vnd.docker.distribution.manifest.list.v2+json", + "schemaVersion": 2 +}` + var index ocispec.Index + assert.NoError(t, json.Unmarshal([]byte(imageIndexJSON), &index)) + + var manifestWriters = []io.Closer{} + + cleanupWriters := func() { + for _, closer := range manifestWriters { + closer.Close() + } + manifestWriters = manifestWriters[:0] + } + defer cleanupWriters() + + // hold the writer by the desc + for _, desc := range index.Manifests { + writer, err := content.OpenWriter(ctx, contentStore, + content.WithDescriptor(desc), + content.WithRef(fmt.Sprintf("manifest-%v", desc.Digest)), + ) + assert.NoError(t, err, "failed to locked manifest") + + t.Logf("locked the manifest %+v", desc) + manifestWriters = append(manifestWriters, writer) + } + + errCh := make(chan error) + go func() { + defer close(errCh) + + _, err := criService.PullImage(ctx, &runtimeapi.PullImageRequest{ + Image: &runtimeapi.ImageSpec{ + Image: pullProgressTestImageName, + }, + }) + errCh <- err + }() + + select { + case <-time.After(defaultImagePullProgressTimeout * 5): + // release the lock + cleanupWriters() + case err := <-errCh: + t.Fatalf("PullImage should not return because the manifest has been locked, but got error=%v", err) + } + assert.NoError(t, <-errCh) +} + +// testCRIImagePullTimeoutByNoDataTransferred tests that +// +// It should fail because there is no data transferred in open http request. +// +// The case uses the local mirror registry to forward request with circuit +// breaker. If the local registry has transferred a certain amount of data in +// connection, it will enable circuit breaker and sleep for a while. For the +// CRI plugin, it will see there is no data transported. And then cancel the +// pulling request when timeout. +// +// This case uses ghcr.io/containerd/registry:2.7 which has one layer > 3MB. +// The circuit breaker will enable after transferred 3MB in one connection. 
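// To make the timing concrete (derived from the constants in this test and in
// pkg/cri/server/image_pull.go): defaultImagePullProgressTimeout is 5s, so the
// reporter polls every max(5s/2, minPullProgressReportInternal) = 5s, while
// the mirror stalls for retryAfter = 100s once 3MB have been copied; the
// reporter therefore sees no new bytes for longer than 5s and cancels the
// pull context long before the stall ends, which is why the test expects a
// context.Canceled error.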
+func testCRIImagePullTimeoutByNoDataTransferred(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + + cli := buildLocalContainerdClient(t, tmpDir) + + mirrorSrv := newMirrorRegistryServer(mirrorRegistryServerConfig{ + limitedBytesPerConn: 1024 * 1024 * 3, // 3MB + retryAfter: 100 * time.Second, + targetURL: &url.URL{ + Scheme: "https", + Host: "ghcr.io", + }, + }) + + ts := setupLocalMirrorRegistry(mirrorSrv) + defer ts.Close() + + mirrorURL, err := url.Parse(ts.URL) + assert.NoError(t, err) + + var hostTomlContent = fmt.Sprintf(` +[host."%s"] + capabilities = ["pull", "resolve", "push"] + skip_verify = true +`, mirrorURL.String()) + + hostCfgDir := filepath.Join(tmpDir, "registrycfg", mirrorURL.Host) + assert.NoError(t, os.MkdirAll(hostCfgDir, 0600)) + + err = os.WriteFile(filepath.Join(hostCfgDir, "hosts.toml"), []byte(hostTomlContent), 0600) + assert.NoError(t, err) + + ctx := namespaces.WithNamespace(context.Background(), k8sNamespace) + for idx, registryCfg := range []criconfig.Registry{ + { + ConfigPath: filepath.Dir(hostCfgDir), + }, + // TODO(fuweid): + // + // Both Mirrors and Configs are deprecated in the future. And + // this registryCfg should also be removed at that time. + { + Mirrors: map[string]criconfig.Mirror{ + mirrorURL.Host: { + Endpoints: []string{mirrorURL.String()}, + }, + }, + Configs: map[string]criconfig.RegistryConfig{ + mirrorURL.Host: { + TLS: &criconfig.TLSConfig{ + InsecureSkipVerify: true, + }, + }, + }, + }, + } { + criService, err := initLocalCRIPlugin(cli, tmpDir, registryCfg) + assert.NoError(t, err) + + dctx, _, err := cli.WithLease(ctx) + assert.NoError(t, err) + + _, err = criService.PullImage(dctx, &runtimeapi.PullImageRequest{ + Image: &runtimeapi.ImageSpec{ + Image: fmt.Sprintf("%s/%s", mirrorURL.Host, "containerd/registry:2.7"), + }, + }) + + assert.Equal(t, errors.Unwrap(err), context.Canceled, "[%v] expected canceled error, but got (%v)", idx, err) + assert.Equal(t, mirrorSrv.limiter.clearHitCircuitBreaker(), true, "[%v] expected to hit circuit breaker", idx) + + // cleanup the temp data by sync delete + lid, ok := leases.FromContext(dctx) + assert.Equal(t, ok, true) + err = cli.LeasesService().Delete(ctx, leases.Lease{ID: lid}, leases.SynchronousDelete) + assert.NoError(t, err) + } +} + +func setupLocalMirrorRegistry(srv *mirrorRegistryServer) *httptest.Server { + return httptest.NewServer(srv) +} + +func newMirrorRegistryServer(cfg mirrorRegistryServerConfig) *mirrorRegistryServer { + return &mirrorRegistryServer{ + client: http.DefaultClient, + limiter: newIOCopyLimiter(cfg.limitedBytesPerConn, cfg.retryAfter), + targetURL: cfg.targetURL, + } +} + +type mirrorRegistryServerConfig struct { + limitedBytesPerConn int + retryAfter time.Duration + targetURL *url.URL +} + +type mirrorRegistryServer struct { + client *http.Client + limiter *ioCopyLimiter + targetURL *url.URL +} + +func (srv *mirrorRegistryServer) ServeHTTP(respW http.ResponseWriter, req *http.Request) { + originalURL := &url.URL{ + Scheme: "http", + Host: req.Host, + } + + req.URL.Host = srv.targetURL.Host + req.URL.Scheme = srv.targetURL.Scheme + req.Host = srv.targetURL.Host + + req.RequestURI = "" + fresp, err := srv.client.Do(req) + if err != nil { + http.Error(respW, fmt.Sprintf("failed to mirror request: %v", err), http.StatusBadGateway) + return + } + defer fresp.Body.Close() + + // copy header and modified that authentication value + authKey := http.CanonicalHeaderKey("WWW-Authenticate") + for key, vals := range fresp.Header { + replace := (key == authKey) + + for 
_, val := range vals { + if replace { + val = strings.Replace(val, srv.targetURL.String(), originalURL.String(), -1) + val = strings.Replace(val, srv.targetURL.Host, originalURL.Host, -1) + } + respW.Header().Add(key, val) + } + } + + respW.WriteHeader(fresp.StatusCode) + if err := srv.limiter.limitedCopy(req.Context(), respW, fresp.Body); err != nil { + log.G(req.Context()).Errorf("failed to forward response: %v", err) + } +} + +var ( + defaultBufSize = 1024 * 4 + + bufPool = sync.Pool{ + New: func() interface{} { + buffer := make([]byte, defaultBufSize) + return &buffer + }, + } +) + +func newIOCopyLimiter(limitedBytesPerConn int, retryAfter time.Duration) *ioCopyLimiter { + return &ioCopyLimiter{ + limitedBytes: limitedBytesPerConn, + retryAfter: retryAfter, + } +} + +// ioCopyLimiter will postpone the data transfer after limitedBytes has been +// transferred, like circuit breaker. +type ioCopyLimiter struct { + limitedBytes int + retryAfter time.Duration + hitCircuitBreaker bool +} + +func (l *ioCopyLimiter) clearHitCircuitBreaker() bool { + last := l.hitCircuitBreaker + l.hitCircuitBreaker = false + return last +} + +func (l *ioCopyLimiter) limitedCopy(ctx context.Context, dst io.Writer, src io.Reader) error { + var ( + bufRef = bufPool.Get().(*[]byte) + buf = *bufRef + timer = time.NewTimer(0) + written int64 + ) + + defer bufPool.Put(bufRef) + + stopTimer := func(t *time.Timer, needRecv bool) { + if !t.Stop() && needRecv { + <-t.C + } + } + + waitForRetry := func(t *time.Timer, delay time.Duration) error { + needRecv := true + + t.Reset(delay) + select { + case <-t.C: + needRecv = false + case <-ctx.Done(): + return ctx.Err() + } + stopTimer(t, needRecv) + return nil + } + + stopTimer(timer, true) + defer timer.Stop() + for { + if written > int64(l.limitedBytes) { + l.hitCircuitBreaker = true + + log.G(ctx).Warnf("after %v bytes transferred, enable breaker and retransfer after %v", written, l.retryAfter) + if wer := waitForRetry(timer, l.retryAfter); wer != nil { + return wer + } + + written = 0 + l.hitCircuitBreaker = false + } + + nr, er := io.ReadAtLeast(src, buf, len(buf)) + if nr > 0 { + nw, ew := dst.Write(buf[0:nr]) + if nw > 0 { + written += int64(nw) + } + if ew != nil { + return ew + } + if nr != nw { + return io.ErrShortWrite + } + } + if er != nil { + if er != io.EOF && er != io.ErrUnexpectedEOF { + return er + } + break + } + } + return nil +} + +// initLocalCRIPlugin uses containerd.Client to init CRI plugin. +// +// NOTE: We don't need to start the CRI plugin here because we just need the +// ImageService API. 
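// An illustrative call site (mirroring the tests above):
//
//	criService, err := initLocalCRIPlugin(cli, tmpDir, criconfig.Registry{})
//	...
//	_, err = criService.PullImage(ctx, &runtimeapi.PullImageRequest{
//		Image: &runtimeapi.ImageSpec{Image: pullProgressTestImageName},
//	})
//
// Only the image-facing methods are exercised; the service is never started.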
+func initLocalCRIPlugin(client *containerd.Client, tmpDir string, registryCfg criconfig.Registry) (criserver.CRIService, error) { + containerdRootDir := filepath.Join(tmpDir, "root") + criWorkDir := filepath.Join(tmpDir, "cri-plugin") + + cfg := criconfig.Config{ + PluginConfig: criconfig.PluginConfig{ + ContainerdConfig: criconfig.ContainerdConfig{ + Snapshotter: containerd.DefaultSnapshotter, + }, + Registry: registryCfg, + ImagePullProgressTimeout: defaultImagePullProgressTimeout.String(), + }, + ContainerdRootDir: containerdRootDir, + RootDir: filepath.Join(criWorkDir, "root"), + StateDir: filepath.Join(criWorkDir, "state"), + } + return criserver.NewCRIService(cfg, client) +} diff --git a/pkg/cri/config/config.go b/pkg/cri/config/config.go index f0dbcf0efba46..13b6150c3c3fc 100644 --- a/pkg/cri/config/config.go +++ b/pkg/cri/config/config.go @@ -313,6 +313,14 @@ type PluginConfig struct { EnableCDI bool `toml:"enable_cdi" json:"enableCDI"` // CDISpecDirs is the list of directories to scan for Container Device Interface Specifications CDISpecDirs []string `toml:"cdi_spec_dirs" json:"cdiSpecDirs"` + // ImagePullProgressTimeout is the maximum duration that there is no + // image data read from image registry in the open connection. It will + // be reset whatever a new byte has been read. If timeout, the image + // pulling will be cancelled. A zero value means there is no timeout. + // + // The string is in the golang duration format, see: + // https://golang.org/pkg/time/#ParseDuration + ImagePullProgressTimeout string `toml:"image_pull_progress_timeout" json:"imagePullProgressTimeout"` } // X509KeyPairStreaming contains the x509 configuration for streaming @@ -459,5 +467,12 @@ func ValidatePluginConfig(ctx context.Context, c *PluginConfig) error { return fmt.Errorf("invalid stream idle timeout: %w", err) } } + + // Validation for image_pull_progress_timeout + if c.ImagePullProgressTimeout != "" { + if _, err := time.ParseDuration(c.ImagePullProgressTimeout); err != nil { + return fmt.Errorf("invalid image pull progress timeout: %w", err) + } + } return nil } diff --git a/pkg/cri/config/config_unix.go b/pkg/cri/config/config_unix.go index 19463b492208b..7eacdf7921cc7 100644 --- a/pkg/cri/config/config_unix.go +++ b/pkg/cri/config/config_unix.go @@ -20,6 +20,8 @@ package config import ( + "time" + "github.com/containerd/containerd" "github.com/containerd/containerd/pkg/cri/streaming" "github.com/pelletier/go-toml" @@ -104,7 +106,8 @@ func DefaultConfig() PluginConfig { ImageDecryption: ImageDecryption{ KeyModel: KeyModelNode, }, - EnableCDI: false, - CDISpecDirs: []string{"/etc/cdi", "/var/run/cdi"}, + EnableCDI: false, + CDISpecDirs: []string{"/etc/cdi", "/var/run/cdi"}, + ImagePullProgressTimeout: time.Minute.String(), } } diff --git a/pkg/cri/config/config_windows.go b/pkg/cri/config/config_windows.go index dd1eb209f90ad..eb60dbe0dfca1 100644 --- a/pkg/cri/config/config_windows.go +++ b/pkg/cri/config/config_windows.go @@ -19,6 +19,7 @@ package config import ( "os" "path/filepath" + "time" "github.com/containerd/containerd" "github.com/containerd/containerd/pkg/cri/streaming" @@ -62,5 +63,6 @@ func DefaultConfig() PluginConfig { ImageDecryption: ImageDecryption{ KeyModel: KeyModelNode, }, + ImagePullProgressTimeout: time.Minute.String(), } } diff --git a/pkg/cri/server/image_pull.go b/pkg/cri/server/image_pull.go index 46c3ffb9ea6e0..1990090530327 100644 --- a/pkg/cri/server/image_pull.go +++ b/pkg/cri/server/image_pull.go @@ -22,12 +22,15 @@ import ( "crypto/x509" "encoding/base64" 
"fmt" + "io" "net" "net/http" "net/url" "os" "path/filepath" "strings" + "sync" + "sync/atomic" "time" "github.com/containerd/containerd" @@ -98,10 +101,20 @@ func (c *criService) PullImage(ctx context.Context, r *runtime.PullImageRequest) if ref != imageRef { log.G(ctx).Debugf("PullImage using normalized image ref: %q", ref) } + + imagePullProgressTimeout, err := time.ParseDuration(c.config.ImagePullProgressTimeout) + if err != nil { + return nil, fmt.Errorf("failed to parse image_pull_progress_timeout %q: %w", c.config.ImagePullProgressTimeout, err) + } + var ( + pctx, pcancel = context.WithCancel(ctx) + + pullReporter = newPullProgressReporter(ref, pcancel, imagePullProgressTimeout) + resolver = docker.NewResolver(docker.ResolverOptions{ Headers: c.config.Registry.Headers, - Hosts: c.registryHosts(ctx, r.GetAuth()), + Hosts: c.registryHosts(ctx, r.GetAuth(), pullReporter.optionUpdateClient), }) isSchema1 bool imageHandler containerdimages.HandlerFunc = func(_ context.Context, @@ -138,7 +151,9 @@ func (c *criService) PullImage(ctx context.Context, r *runtime.PullImageRequest) containerd.WithChildLabelMap(containerdimages.ChildGCLabelsFilterLayers)) } - image, err := c.client.Pull(ctx, ref, pullOpts...) + pullReporter.start(pctx) + image, err := c.client.Pull(pctx, ref, pullOpts...) + pcancel() if err != nil { return nil, fmt.Errorf("failed to pull and unpack image %q: %w", ref, err) } @@ -332,10 +347,12 @@ func hostDirFromRoots(roots []string) func(string) (string, error) { } // registryHosts is the registry hosts to be used by the resolver. -func (c *criService) registryHosts(ctx context.Context, auth *runtime.AuthConfig) docker.RegistryHosts { +func (c *criService) registryHosts(ctx context.Context, auth *runtime.AuthConfig, updateClientFn config.UpdateClientFunc) docker.RegistryHosts { paths := filepath.SplitList(c.config.Registry.ConfigPath) if len(paths) > 0 { - hostOptions := config.HostOptions{} + hostOptions := config.HostOptions{ + UpdateClient: updateClientFn, + } hostOptions.Credentials = func(host string) (string, string, error) { hostauth := auth if hostauth == nil { @@ -388,6 +405,13 @@ func (c *criService) registryHosts(ctx context.Context, auth *runtime.AuthConfig if auth == nil && config.Auth != nil { auth = toRuntimeAuthConfig(*config.Auth) } + + if updateClientFn != nil { + if err := updateClientFn(client); err != nil { + return nil, fmt.Errorf("failed to update http client: %w", err) + } + } + authorizer := docker.NewDockerAuthorizer( docker.WithAuthClient(client), docker.WithAuthCreds(func(host string) (string, string, error) { @@ -579,3 +603,186 @@ func getLayers(ctx context.Context, key string, descs []imagespec.Descriptor, va } return } + +const ( + // minPullProgressReportInternal is used to prevent the reporter from + // eating more CPU resources + minPullProgressReportInternal = 5 * time.Second + // defaultPullProgressReportInterval represents that how often the + // reporter checks that pull progress. + defaultPullProgressReportInterval = 10 * time.Second +) + +// pullProgressReporter is used to check single PullImage progress. 
+type pullProgressReporter struct { + ref string + cancel context.CancelFunc + reqReporter pullRequestReporter + timeout time.Duration +} + +func newPullProgressReporter(ref string, cancel context.CancelFunc, timeout time.Duration) *pullProgressReporter { + return &pullProgressReporter{ + ref: ref, + cancel: cancel, + reqReporter: pullRequestReporter{}, + timeout: timeout, + } +} + +func (reporter *pullProgressReporter) optionUpdateClient(client *http.Client) error { + client.Transport = &pullRequestReporterRoundTripper{ + rt: client.Transport, + reqReporter: &reporter.reqReporter, + } + return nil +} + +func (reporter *pullProgressReporter) start(ctx context.Context) { + if reporter.timeout == 0 { + log.G(ctx).Infof("no timeout and will not start pulling image %s reporter", reporter.ref) + return + } + + go func() { + var ( + reportInterval = defaultPullProgressReportInterval + + lastSeenBytesRead = uint64(0) + lastSeenTimestamp = time.Now() + ) + + // check progress more frequently if timeout < default internal + if reporter.timeout < reportInterval { + reportInterval = reporter.timeout / 2 + + if reportInterval < minPullProgressReportInternal { + reportInterval = minPullProgressReportInternal + } + } + + var ticker = time.NewTicker(reportInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + activeReqs, bytesRead := reporter.reqReporter.status() + + log.G(ctx).WithField("ref", reporter.ref). + WithField("activeReqs", activeReqs). + WithField("totalBytesRead", bytesRead). + WithField("lastSeenBytesRead", lastSeenBytesRead). + WithField("lastSeenTimestamp", lastSeenTimestamp). + WithField("reportInterval", reportInterval). + Tracef("progress for image pull") + + if activeReqs == 0 || bytesRead > lastSeenBytesRead { + lastSeenBytesRead = bytesRead + lastSeenTimestamp = time.Now() + continue + } + + if time.Since(lastSeenTimestamp) > reporter.timeout { + log.G(ctx).Errorf("cancel pulling image %s because of no progress in %v", reporter.ref, reporter.timeout) + reporter.cancel() + return + } + case <-ctx.Done(): + activeReqs, bytesRead := reporter.reqReporter.status() + log.G(ctx).Infof("stop pulling image %s: active requests=%v, bytes read=%v", reporter.ref, activeReqs, bytesRead) + return + } + } + }() +} + +// countingReadCloser wraps http.Response.Body with pull request reporter, +// which is used by pullRequestReporterRoundTripper. +type countingReadCloser struct { + once sync.Once + + rc io.ReadCloser + reqReporter *pullRequestReporter +} + +// Read reads bytes from original io.ReadCloser and increases bytes in +// pull request reporter. +func (r *countingReadCloser) Read(p []byte) (int, error) { + n, err := r.rc.Read(p) + r.reqReporter.incByteRead(uint64(n)) + return n, err +} + +// Close closes the original io.ReadCloser and only decreases the number of +// active pull requests once. +func (r *countingReadCloser) Close() error { + err := r.rc.Close() + r.once.Do(r.reqReporter.decRequest) + return err +} + +// pullRequestReporter is used to track the progress per each criapi.PullImage. +type pullRequestReporter struct { + // activeReqs indicates that current number of active pulling requests, + // including auth requests. + activeReqs int32 + // totalBytesRead indicates that the total bytes has been read from + // remote registry. 
+ totalBytesRead uint64 +} + +func (reporter *pullRequestReporter) incRequest() { + atomic.AddInt32(&reporter.activeReqs, 1) +} + +func (reporter *pullRequestReporter) decRequest() { + atomic.AddInt32(&reporter.activeReqs, -1) +} + +func (reporter *pullRequestReporter) incByteRead(nr uint64) { + atomic.AddUint64(&reporter.totalBytesRead, nr) +} + +func (reporter *pullRequestReporter) status() (currentReqs int32, totalBytesRead uint64) { + currentReqs = atomic.LoadInt32(&reporter.activeReqs) + totalBytesRead = atomic.LoadUint64(&reporter.totalBytesRead) + return currentReqs, totalBytesRead +} + +// pullRequestReporterRoundTripper wraps http.RoundTripper with pull request +// reporter which is used to track the progress of active http request with +// counting readable http.Response.Body. +// +// NOTE: +// +// Although containerd provides ingester manager to track the progress +// of pulling request, for example `ctr image pull` shows the console progress +// bar, it needs more CPU resources to open/read the ingested files with +// acquiring containerd metadata plugin's boltdb lock. +// +// Before sending HTTP request to registry, the containerd.Client.Pull library +// will open writer by containerd ingester manager. Based on this, the +// http.RoundTripper wrapper can track the active progress with lower overhead +// even if the ref has been locked in ingester manager by other Pull request. +type pullRequestReporterRoundTripper struct { + rt http.RoundTripper + + reqReporter *pullRequestReporter +} + +func (rt *pullRequestReporterRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + rt.reqReporter.incRequest() + + resp, err := rt.rt.RoundTrip(req) + if err != nil { + rt.reqReporter.decRequest() + return nil, err + } + + resp.Body = &countingReadCloser{ + rc: resp.Body, + reqReporter: rt.reqReporter, + } + return resp, err +}
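
The patch above tracks pull progress by wrapping the registry HTTP client. The following standalone sketch (not taken from the patch; countingTransport, readCounter and the test server are hypothetical names) shows the same counting-RoundTripper technique in isolation: the transport wrapper counts response-body bytes so that a watchdog goroutine could cancel a context when the counter stops growing.

package main

import (
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"sync/atomic"
)

// readCounter wraps a response body and adds every chunk it reads to a
// shared byte counter.
type readCounter struct {
	rc    io.ReadCloser
	total *uint64
}

func (r *readCounter) Read(p []byte) (int, error) {
	n, err := r.rc.Read(p)
	atomic.AddUint64(r.total, uint64(n))
	return n, err
}

func (r *readCounter) Close() error { return r.rc.Close() }

// countingTransport wraps an http.RoundTripper and swaps every response body
// for a readCounter, mirroring the role of pullRequestReporterRoundTripper.
type countingTransport struct {
	rt    http.RoundTripper
	total uint64
}

func (t *countingTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	resp, err := t.rt.RoundTrip(req)
	if err != nil {
		return nil, err
	}
	resp.Body = &readCounter{rc: resp.Body, total: &t.total}
	return resp, nil
}

func main() {
	// A stand-in for a registry blob endpoint.
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write(make([]byte, 1<<20)) // pretend this is a 1 MiB layer
	}))
	defer srv.Close()

	ct := &countingTransport{rt: http.DefaultTransport}
	client := &http.Client{Transport: ct}

	resp, err := client.Get(srv.URL)
	if err != nil {
		panic(err)
	}
	io.Copy(io.Discard, resp.Body)
	resp.Body.Close()

	// A progress watchdog would poll this counter on a ticker and cancel the
	// pull context if it stops increasing for longer than the timeout.
	fmt.Println("bytes read:", atomic.LoadUint64(&ct.total))
}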