diff --git a/libcontainer/cgroups/file.go b/libcontainer/cgroups/file.go index b1888f7ede6..aaf0ef1c978 100644 --- a/libcontainer/cgroups/file.go +++ b/libcontainer/cgroups/file.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "os" + "path" "strings" "sync" @@ -13,7 +14,11 @@ import ( ) // OpenFile opens a cgroup file in a given dir with given flags. -// It is supposed to be used for cgroup files only. +// It is supposed to be used for cgroup files only, and returns +// an error is the file is not a cgroup file. +// +// Arguments dir and file are joined together to form an absolute path +// to a file being opened. func OpenFile(dir, file string, flags int) (*os.File, error) { if dir == "" { return nil, fmt.Errorf("no directory specified for %s", file) @@ -109,8 +114,6 @@ func prepareOpenat2() error { return prepErr } -// OpenFile opens a cgroup file in a given dir with given flags. -// It is supposed to be used for cgroup files only. func openFile(dir, file string, flags int) (*os.File, error) { mode := os.FileMode(0) if TestMode && flags&os.O_WRONLY != 0 { @@ -118,34 +121,36 @@ func openFile(dir, file string, flags int) (*os.File, error) { flags |= os.O_TRUNC | os.O_CREATE mode = 0o600 } + path := path.Join(dir, file) if prepareOpenat2() != nil { - return openFallback(dir, file, flags, mode) + return openFallback(path, flags, mode) } - reldir := strings.TrimPrefix(dir, cgroupfsPrefix) - if len(reldir) == len(dir) { // non-standard path, old system? - return openFallback(dir, file, flags, mode) + relPath := strings.TrimPrefix(path, cgroupfsPrefix) + if len(relPath) == len(path) { // non-standard path, old system? + return openFallback(path, flags, mode) } - relname := reldir + "/" + file - fd, err := unix.Openat2(cgroupFd, relname, + fd, err := unix.Openat2(cgroupFd, relPath, &unix.OpenHow{ Resolve: resolveFlags, Flags: uint64(flags) | unix.O_CLOEXEC, Mode: uint64(mode), }) if err != nil { - return nil, &os.PathError{Op: "openat2", Path: dir + "/" + file, Err: err} + return nil, &os.PathError{Op: "openat2", Path: path, Err: err} } - return os.NewFile(uintptr(fd), cgroupfsPrefix+relname), nil + return os.NewFile(uintptr(fd), relPath), nil } var errNotCgroupfs = errors.New("not a cgroup file") -// openFallback is used when openat2(2) is not available. It checks the opened +// Can be changed by unit tests. +var openFallback = openAndCheck + +// openAndCheck is used when openat2(2) is not available. It checks the opened // file is on cgroupfs, returning an error otherwise. -func openFallback(dir, file string, flags int, mode os.FileMode) (*os.File, error) { - path := dir + "/" + file +func openAndCheck(path string, flags int, mode os.FileMode) (*os.File, error) { fd, err := os.OpenFile(path, flags, mode) if err != nil { return nil, err diff --git a/libcontainer/cgroups/file_test.go b/libcontainer/cgroups/file_test.go index 4b2cb895007..9f7be0b885a 100644 --- a/libcontainer/cgroups/file_test.go +++ b/libcontainer/cgroups/file_test.go @@ -3,6 +3,7 @@ package cgroups import ( + "errors" "fmt" "os" "path/filepath" @@ -40,3 +41,35 @@ func TestWriteCgroupFileHandlesInterrupt(t *testing.T) { } } } + +func TestOpenat2(t *testing.T) { + if !IsCgroup2UnifiedMode() { + // The reason is many test cases below test opening files from + // the top-level directory, where cgroup v1 has no files. + t.Skip("test requires cgroup v2") + } + + // Make sure we test openat2, not its fallback. + openFallback = func(_ string, _ int, _ os.FileMode) (*os.File, error) { + return nil, errors.New("fallback") + } + defer func() { openFallback = openAndCheck }() + + for _, tc := range []struct{ dir, file string }{ + {"/sys/fs/cgroup", "cgroup.controllers"}, + {"/sys/fs/cgroup", "/cgroup.controllers"}, + {"/sys/fs/cgroup/", "cgroup.controllers"}, + {"/sys/fs/cgroup/", "/cgroup.controllers"}, + {"/sys/fs/cgroup/user.slice", "cgroup.controllers"}, + {"/sys/fs/cgroup/user.slice/", "/cgroup.controllers"}, + {"/", "/sys/fs/cgroup/cgroup.controllers"}, + {"/", "sys/fs/cgroup/cgroup.controllers"}, + {"/sys/fs/cgroup/cgroup.controllers", ""}, + } { + fd, err := OpenFile(tc.dir, tc.file, os.O_RDONLY) + if err != nil { + t.Errorf("case %+v: %v", tc, err) + } + fd.Close() + } +}