diff --git a/internal/os/unix/dirfd_linux.go b/internal/os/unix/dirfd_linux.go new file mode 100644 index 000000000000..144bb1d4703b --- /dev/null +++ b/internal/os/unix/dirfd_linux.go @@ -0,0 +1,124 @@ +/* +Copyright 2024 k0s authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package unix + +import ( + "cmp" + "os" + "sync/atomic" + "syscall" + + "golang.org/x/sys/unix" +) + +// An open Linux-native handle to some path on the file system. +type LinuxPath interface { + Path + + // Stats this path using the fstatat(path, "", AT_EMPTY_PATH) syscall. + StatSelf() (*FileInfo, error) +} + +var _ LinuxPath = (*PathFD)(nil) + +// Stats this path using the fstatat(path, "", AT_EMPTY_PATH) syscall. +func (p *PathFD) StatSelf() (*FileInfo, error) { + return p.UnwrapDir().StatSelf() +} + +var _ LinuxPath = (*DirFD)(nil) + +// Stats this path using the fstatat(path, "", AT_EMPTY_PATH) syscall. +func (d *DirFD) StatSelf() (*FileInfo, error) { + return d.StatAt("", unix.AT_EMPTY_PATH) +} + +// Opens the path with the given name. +// The path is opened relative to the receiver, using the openat2 syscall. +// +// Note that, in contrast to [os.Open] and [os.OpenFile], the returned +// descriptor is not put into non-blocking mode automatically. Callers may +// decide if they want this by setting the [syscall.O_NONBLOCK] flag. +// +// Available since Linux 5.6 (April 2020). +// +// https://www.man7.org/linux/man-pages/man2/openat2.2.html +// https://git.kernel.org/pub/scm/linux/kernel/git/stable/linux.git/commit/?id=fddb5d430ad9fa91b49b1d34d0202ffe2fa0e179 +func (d *DirFD) Open2(name string, how unix.OpenHow) (*PathFD, error) { + var opened int + if err := openAt2Support.guard(func() error { + return syscallControl(d, func(fd uintptr) (err error) { + how.Flags |= unix.O_CLOEXEC + opened, err = unix.Openat2(int(fd), name, &how) + if err == nil { + return nil + } + return &os.PathError{Op: "openat2", Path: name, Err: err} + }) + }); err != nil { + return nil, err + } + + return (*PathFD)(os.NewFile(uintptr(opened), name)), nil +} + +// Opens the directory with the given name by using the openat2 syscall. +// +// See [DirFD.Open2]. +func (d *DirFD) OpenDir2(name string, how unix.OpenHow) (*DirFD, error) { + how.Flags |= unix.O_DIRECTORY + f, err := d.Open2(name, how) + return f.UnwrapDir(), err +} + +var openAt2Support = runtimeSupport{test: func() error { + // Try to open the current working directory without requiring any + // permissions (O_PATH). If that fails, assume that openat2 is unusable. + var cwd int = unix.AT_FDCWD + fd, err := unix.Openat2(cwd, ".", &unix.OpenHow{Flags: unix.O_PATH | unix.O_CLOEXEC}) + if err != nil { + return &os.SyscallError{Syscall: "openat2", Err: syscall.ENOSYS} + } + _ = unix.Close(fd) + return nil +}} + +type runtimeSupport struct { + test func() error + err atomic.Pointer[error] +} + +func (t *runtimeSupport) guard(f func() error) error { + if err := t.err.Load(); err != nil { + if *err == nil { + return f() + } + return *err + } + + err := f() + if err == nil { + t.err.Swap(&err) + return nil + } + + testErr := t.test() + if !t.err.CompareAndSwap(nil, &testErr) { + testErr = *t.err.Load() + } + return cmp.Or(testErr, err) +} diff --git a/internal/os/unix/dirfd_linux_test.go b/internal/os/unix/dirfd_linux_test.go new file mode 100644 index 000000000000..024bb9c00d06 --- /dev/null +++ b/internal/os/unix/dirfd_linux_test.go @@ -0,0 +1,47 @@ +/* +Copyright 2024 k0s authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package unix_test + +import ( + "testing" + + osunix "github.com/k0sproject/k0s/internal/os/unix" + "golang.org/x/sys/unix" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPathFD_StatSelf(t *testing.T) { + dirPath := t.TempDir() + + p, err := osunix.OpenDir(dirPath, unix.O_PATH) + require.NoError(t, err) + t.Cleanup(func() { assert.NoError(t, p.Close()) }) + + // An O_PATH descriptor cannot read anything. + _, err = p.Readdirnames(1) + assert.ErrorIs(t, err, unix.EBADF) + + // Verify that the fstatat syscall works for O_PATH file descriptors. + // It's not documented in the Linux man pages, just fstat is. + // See open(2). + stat, err := p.StatSelf() + if assert.NoError(t, err) { + assert.True(t, stat.IsDir()) + } +} diff --git a/internal/os/unix/dirfd_unix.go b/internal/os/unix/dirfd_unix.go new file mode 100644 index 000000000000..5f98212143b2 --- /dev/null +++ b/internal/os/unix/dirfd_unix.go @@ -0,0 +1,353 @@ +//go:build unix + +/* +Copyright 2024 k0s authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package unix + +import ( + "errors" + "io" + "iter" + "os" + "path/filepath" + "syscall" + "time" + + "golang.org/x/sys/unix" +) + +// An open handle to some path on the file system. +type Path interface { + io.Closer + syscall.Conn + Name() string // Delegates to [os.File.Name]. + Stat() (os.FileInfo, error) // Delegates to [os.File.Stat]. +} + +// A file descriptor pointing to a path. +// It is unspecified if that descriptor is referring to a file or a directory. +type PathFD os.File + +// The interface that [PathFD] is about to implement. +var _ Path = (*PathFD)(nil) + +// Delegates to [os.File.Close]. +func (p *PathFD) Close() error { return (*os.File)(p).Close() } + +// Delegates to [os.File.Name]. +func (p *PathFD) Name() string { return (*os.File)(p).Name() } + +// Delegates to [os.File.Stat]. +func (p *PathFD) Stat() (os.FileInfo, error) { return (*os.File)(p).Stat() } + +// Delegates to [os.File.SyscallConn]. +func (p *PathFD) SyscallConn() (syscall.RawConn, error) { return (*os.File)(p).SyscallConn() } + +// Converts this pointer to an [*os.File] without any additional checks. +// +// Note that both [os.File.ReadDir] and [os.File.Readdir] will NOT work if this +// pointer has been opened via a [DirFD] pointer. +// See [DirFD.Readdirnames] for details. +func (f *PathFD) UnwrapFile() *os.File { return (*os.File)(f) } + +// Converts this pointer to a [*DirFD] without any additional checks. +func (f *PathFD) UnwrapDir() *DirFD { return (*DirFD)(f) } + +// A file descriptor pointing to a directory (a.k.a. dirfd). It uses the +// syscalls that accept a dirfd, i.e. openat, fstatat ... +// +// Using a dirfd, as opposed to using a path (or path prefix) for all +// operations, offers some unique features: Operations are more consistent. A +// dirfd ensures that all operations are relative to the same directory +// instance. If the directory is renamed or moved, the dirfd remains valid and +// operations continue to work as expected, which is not the case when using +// paths. Using a dirfd can also be more secure. If a directory path is given as +// a string and used repeatedly, there's a risk that the path could be +// maliciously altered (e.g., through symbolic link attacks). Using a dirfd +// ensures that operations use the original directory, mitigating this type of +// attack. +type DirFD os.File + +// The interface that [DirFD] is about to implement. +var _ Path = (*DirFD)(nil) + +// Opens a [DirFD] referring to the given path. +// +// Note that this is not a chroot: The *at syscalls will only use dirfd to +// resolve relative paths, and will happily follow symlinks and cross mount +// points. +func OpenDir(path string, flags int) (*DirFD, error) { + // Use the raw syscall instead of os.OpenFile here, as the latter tries to + // put the fds into non-blocking mode. + fd, err := syscall.Open(path, flags|syscall.O_DIRECTORY|syscall.O_CLOEXEC, 0) + if err != nil { + return nil, &os.PathError{Op: "open", Path: path, Err: err} + } + + return (*DirFD)(os.NewFile(uintptr(fd), path)), nil +} + +// Delegates to [os.File.Close]. +func (d *DirFD) Close() error { return (*os.File)(d).Close() } + +// Delegates to [os.File.SyscallConn]. +func (d *DirFD) SyscallConn() (syscall.RawConn, error) { return (*os.File)(d).SyscallConn() } + +// Delegates to [os.File.Name]. +func (d *DirFD) Name() string { return (*os.File)(d).Name() } + +// Delegates to [io.File.Stat]. +func (d *DirFD) Stat() (os.FileInfo, error) { return (*os.File)(d).Stat() } + +// Opens the path with the given name. +// The path is opened relative to the receiver, using the openat syscall. +// +// Note that, in contrast to [os.Open] and [os.OpenFile], the returned +// descriptor is not put into non-blocking mode automatically. Callers may +// decide if they want this by setting the [unix.O_NONBLOCK] flag. +// +// https://www.man7.org/linux/man-pages/man2/open.2.html +func (d *DirFD) Open(name string, flags int, mode os.FileMode) (*PathFD, error) { + var opened int + err := syscallControl(d, func(fd uintptr) error { + flags, mode, err := sysOpenFlags(flags, mode) + if err != nil { + return &os.PathError{Op: "openat", Path: name, Err: err} + } + + opened, err = unix.Openat(int(fd), name, flags, mode) + if err != nil { + return &os.PathError{Op: "openat", Path: name, Err: err} + } + + return nil + }) + if err != nil { + return nil, err + } + + return (*PathFD)(os.NewFile(uintptr(opened), name)), nil +} + +// Opens the directory with the given name. +// The name is opened relative to the receiver, using the openat syscall. +func (d *DirFD) OpenDir(name string, flags int) (*DirFD, error) { + f, err := d.Open(name, flags|unix.O_DIRECTORY, 0) + return f.UnwrapDir(), err +} + +// Stats the path with the given name. +// The name is interpreted relative to the receiver, using the fstatat syscall. +// +// https://www.man7.org/linux/man-pages/man2/stat.2.html +func (d *DirFD) StatAt(name string, flags int) (*FileInfo, error) { + info := FileInfo{Path: name} + if err := syscallControl(d, func(fd uintptr) error { + if err := unix.Fstatat(int(fd), name, (*unix.Stat_t)(&info.Stat), flags); err != nil { + return &os.PathError{Op: "fstatat", Path: name, Err: err} + } + + return nil + }); err != nil { + return nil, err + } + + return &info, nil +} + +// Creates a new directory, just as [os.Mkdir] does. +// The directory is created relative to the receiver, using the mkdirat syscall. +// +// https://www.man7.org/linux/man-pages/man2/mkdir.2.html +func (d *DirFD) Mkdir(name string, mode os.FileMode) error { + return syscallControl(d, func(fd uintptr) error { + if err := unix.Mkdirat(int(fd), name, toSysMode(mode)); err != nil { + return &os.PathError{Op: "mkdirat", Path: name, Err: err} + } + return nil + }) +} + +// Remove the name and possibly the file it refers to. +// The name is removed relative to the receiver, using the unlinkat syscall. +// +// https://www.man7.org/linux/man-pages/man2/unlink.2.html +func (d *DirFD) Remove(name string) error { + return d.unlink(name, 0) +} + +// Remove the directory with the given name using the unlinkat syscall. +// The name is removed relative to the receiver, using the unlinkat syscall. +// +// https://www.man7.org/linux/man-pages/man2/unlink.2.html +func (d *DirFD) RemoveDir(name string) error { + return d.unlink(name, unix.AT_REMOVEDIR) +} + +func (d *DirFD) unlink(name string, flags int) error { + return syscallControl(d, func(fd uintptr) error { + err := unix.Unlinkat(int(fd), name, flags) + if err != nil { + return &os.PathError{Op: "unlinkat", Path: name, Err: err} + } + + return nil + }) +} + +// Delegates to [os.File.Readdirnames]. +// +// This is the only "safe" option. Both [os.File.ReadDir] and [os.File.Readdir] +// will NOT work because of the way the standard library handles directory +// entries: Both methods may end up using the lstat syscall to stat the +// directory entry pathnames under certain circumstances, which violates the +// assumptions of DirFD, and at best will produce runtime errors or return false +// data, or worse. Possible workarounds would be either to use +// [os.File.Readdirnames] internally and do an fstatat syscall for each of the +// returned pathnames (with a significant performance penalty), or to +// reimplement substantial OS-dependent parts of the standard library's internal +// dir entry handling (which feels like the "nuclear option"). For this reason, +// DirFD cannot simply implement [fs.FS], since the stat-like information should +// also be queryable in the [fs.DirEntry] interface. +func (d *DirFD) Readdirnames(n int) ([]string, error) { + return (*os.File)(d).Readdirnames(n) +} + +// Iterates over all the directory entries, returning their names, in no +// particular order. +func (d *DirFD) ReadEntryNames() iter.Seq2[string, error] { + return func(yield func(string, error) bool) { + for { + // Using n=1 is required in order to be able + // to resume iteration after early breaks. + names, err := d.Readdirnames(1) + var eof bool + if err != nil { + if !errors.Is(err, io.EOF) { + yield("", err) + return + } + eof = true + } + + for _, name := range names { + if !yield(name, nil) { + return + } + } + + if eof { + return + } + } + } +} + +type Stat unix.Stat_t + +func (s *Stat) ToFileMode() os.FileMode { return toFileMode(s.Mode) } +func (s *Stat) IsDir() bool { return s.Mode&unix.S_IFMT == unix.S_IFDIR } +func (s *Stat) ModTime() time.Time { return time.Unix(s.Mtim.Unix()) } +func (s *Stat) Sys() any { return (*unix.Stat_t)(s) } + +type FileInfo struct { + Path string + Stat +} + +var _ os.FileInfo = (*FileInfo)(nil) + +func (i *FileInfo) Name() string { return filepath.Base(i.Path) } +func (i *FileInfo) Size() int64 { return i.Stat.Size } +func (i *FileInfo) Mode() os.FileMode { return i.ToFileMode() } + +func toFileMode[T ~uint16 | ~uint32](unixMode T) os.FileMode { + fileMode := os.FileMode(unixMode) & os.ModePerm + + // https://www.man7.org/linux/man-pages/man2/fstatat.2.html#EXAMPLES + + switch unixMode & unix.S_IFMT { + case unix.S_IFREG: // regular file + // nothing to do + case unix.S_IFDIR: // directory + fileMode |= os.ModeDir + case unix.S_IFIFO: // FIFO/pipe + fileMode |= os.ModeNamedPipe + case unix.S_IFLNK: // symlink + fileMode |= os.ModeSymlink + case unix.S_IFSOCK: // socket + fileMode |= os.ModeSocket + case unix.S_IFCHR: // character device + fileMode |= os.ModeCharDevice + fallthrough + case unix.S_IFBLK: // block device + fileMode |= os.ModeDevice + default: // unknown? + fileMode |= os.ModeIrregular + } + + if unixMode&unix.S_ISGID != 0 { + fileMode |= os.ModeSetgid + } + if unixMode&unix.S_ISUID != 0 { + fileMode |= os.ModeSetuid + } + if unixMode&unix.S_ISVTX != 0 { + fileMode |= os.ModeSticky + } + + return fileMode +} + +func sysOpenFlags(flags int, mode os.FileMode) (int, uint32, error) { + const mask = os.ModePerm | os.ModeSetuid | os.ModeSetgid | os.ModeSticky + if mode != (mode & mask) { + return 0, 0, errors.New("invalid mode bits") + } + if mode != 0 && flags|os.O_CREATE == 0 { + return 0, 0, errors.New("mode may only be used when creating") + } + + return flags | syscall.O_CLOEXEC, toSysMode(mode), nil +} + +func toSysMode(mode os.FileMode) uint32 { + sysMode := uint32(mode & os.ModePerm) + if mode&os.ModeSetuid != 0 { + sysMode |= syscall.S_ISUID + } + if mode&os.ModeSetgid != 0 { + sysMode |= syscall.S_ISGID + } + if mode&os.ModeSticky != 0 { + sysMode |= syscall.S_ISVTX + } + return sysMode +} + +func syscallControl[C syscall.Conn](conn C, f func(fd uintptr) error) error { + rawConn, err := conn.SyscallConn() + if err != nil { + return err + } + + outerErr := rawConn.Control(func(fd uintptr) { err = f(fd) }) + if outerErr != nil { + return outerErr + } + return err +} diff --git a/internal/os/unix/dirfd_unix_test.go b/internal/os/unix/dirfd_unix_test.go new file mode 100644 index 000000000000..f16dbb669c16 --- /dev/null +++ b/internal/os/unix/dirfd_unix_test.go @@ -0,0 +1,212 @@ +//go:build unix + +/* +Copyright 2024 k0s authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package unix_test + +import ( + "io" + "iter" + "os" + "path/filepath" + "strconv" + "sync" + "syscall" + "testing" + "time" + + osunix "github.com/k0sproject/k0s/internal/os/unix" + "golang.org/x/sys/unix" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDirFD_NotExist(t *testing.T) { + path := filepath.Join(t.TempDir(), "foo") + + d, err := osunix.OpenDir(path, 0) + if err == nil { + assert.NoError(t, d.Close()) + } + assert.ErrorIs(t, err, os.ErrNotExist) +} + +func TestDirFD_Empty(t *testing.T) { + path := t.TempDir() + + d, err := osunix.OpenDir(path, 0) + require.NoError(t, err) + t.Cleanup(func() { assert.NoError(t, d.Close()) }) + + foo := "foo" + assertENOENT := func(t *testing.T, op string, err error) { + var pathErr *os.PathError + if assert.ErrorAs(t, err, &pathErr) { + assert.Equal(t, op, pathErr.Op) + assert.Equal(t, foo, pathErr.Path) + assert.Equal(t, syscall.ENOENT, pathErr.Err) + } + } + + _, err = d.OpenDir(foo, 0) + assertENOENT(t, "openat", err) + + _, err = d.StatAt(foo, 0) + assertENOENT(t, "fstatat", err) + + err = d.Remove(foo) + assertENOENT(t, "unlinkat", err) + + err = d.RemoveDir(foo) + assertENOENT(t, "unlinkat", err) + + if entries, err := d.Readdirnames(1); assert.Equal(t, io.EOF, err) { + assert.Empty(t, entries) + } +} + +func TestDirFD_Mkdir(t *testing.T) { + path := t.TempDir() + + d, err := osunix.OpenDir(path, 0) + require.NoError(t, err) + t.Cleanup(func() { assert.NoError(t, d.Close()) }) + + require.NoError(t, os.WriteFile(filepath.Join(path, "foo"), []byte("lorem"), 0644)) + require.NoError(t, os.Mkdir(filepath.Join(path, "bar"), 0755)) + + err = d.Mkdir("foo", 0755) + if pathErr := (*os.PathError)(nil); assert.ErrorAs(t, err, &pathErr) { + assert.Equal(t, "mkdirat", pathErr.Op) + assert.Equal(t, "foo", pathErr.Path) + assert.Equal(t, syscall.EEXIST, pathErr.Err) + } + + err = d.Mkdir("bar", 0755) + if pathErr := (*os.PathError)(nil); assert.ErrorAs(t, err, &pathErr) { + assert.Equal(t, "mkdirat", pathErr.Op) + assert.Equal(t, "bar", pathErr.Path) + assert.Equal(t, syscall.EEXIST, pathErr.Err) + } + + err = d.Mkdir("baz", 0755) + if assert.NoError(t, err) { + stat, err := os.Stat(filepath.Join(path, "baz")) + if assert.NoError(t, err) { + assert.Equal(t, os.FileMode(0755), stat.Mode()&os.ModePerm) + assert.True(t, stat.IsDir()) + } + } +} + +func TestDirFD_Filled(t *testing.T) { + dirPath := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(dirPath, "foo"), []byte("lorem"), 0644)) + require.NoError(t, os.Mkdir(filepath.Join(dirPath, "bar"), 0755)) + require.NoError(t, os.WriteFile(filepath.Join(dirPath, "bar", "baz"), []byte("ipsum"), 0400)) + + now := time.Now() + require.NoError(t, os.Chtimes(filepath.Join(dirPath, "foo"), time.Time{}, now.Add(-3*time.Minute))) + require.NoError(t, os.Chtimes(filepath.Join(dirPath, "bar", "baz"), time.Time{}, now.Add(-2*time.Minute))) + require.NoError(t, os.Chtimes(filepath.Join(dirPath, "bar"), time.Time{}, now.Add(-1*time.Minute))) + + d, err := osunix.OpenDir(dirPath, 0) + require.NoError(t, err) + t.Cleanup(func() { assert.NoError(t, d.Close()) }) + + // Stat foo and match contents. + if stat, err := d.StatAt("foo", 0); assert.NoError(t, err) { + assert.Equal(t, "foo", stat.Name()) + assert.Equal(t, int64(5), stat.Size()) + assert.WithinDuration(t, now.Add(-3*time.Minute), stat.ModTime(), 0) + assert.Equal(t, os.FileMode(0644), stat.Mode()) + assert.False(t, stat.IsDir()) + assert.IsType(t, new(unix.Stat_t), stat.Sys()) + } + + // Stat bar and match contents. + if stat, err := d.StatAt("bar", 0); assert.NoError(t, err) { + assert.Equal(t, "bar", stat.Name()) + assert.Greater(t, stat.Size(), int64(0)) + assert.WithinDuration(t, now.Add(-1*time.Minute), stat.ModTime(), 0) + assert.Equal(t, os.FileMode(0755)|os.ModeDir, stat.Mode()) + assert.True(t, stat.IsDir()) + assert.IsType(t, new(unix.Stat_t), stat.Sys()) + } + + // Stat bar/baz and match contents. + if stat, err := d.StatAt(filepath.Join("bar", "baz"), 0); assert.NoError(t, err) { + assert.Equal(t, "baz", stat.Name()) + assert.Equal(t, int64(5), stat.Size()) + assert.WithinDuration(t, now.Add(-2*time.Minute), stat.ModTime(), 0) + assert.Equal(t, os.FileMode(0400), stat.Mode()) + assert.False(t, stat.IsDir()) + assert.IsType(t, new(unix.Stat_t), stat.Sys()) + } + + // List directory contents and match for correctness. + entries, err := d.Readdirnames(10) + if assert.NoError(t, err) && assert.Len(t, entries, 2) { + assert.ElementsMatch(t, entries, []string{"foo", "bar"}) + } + entries, err = d.Readdirnames(10) + assert.Empty(t, entries) + assert.Same(t, io.EOF, err) +} + +func TestDirFD_Entries(t *testing.T) { + dirPath, expectedNames := t.TempDir(), make([]string, 10) + for i := range expectedNames { + expectedNames[i] = strconv.Itoa(i) + require.NoError(t, os.WriteFile(filepath.Join(dirPath, expectedNames[i]), nil, 0644)) + } + + d, err := osunix.OpenDir(dirPath, 0) + require.NoError(t, err) + close := sync.OnceFunc(func() { assert.NoError(t, d.Close()) }) + t.Cleanup(close) + + var names []string + for name, err := range d.ReadEntryNames() { + require.NoError(t, err) + names = append(names, name) + if len(names) >= len(expectedNames)/2 { + break // test early break + } + } + for name, err := range d.ReadEntryNames() { + require.NoError(t, err) + names = append(names, name) + } + + assert.ElementsMatch(t, expectedNames, names) + + for range d.ReadEntryNames() { + require.Fail(t, "Shouldn't yield any additional entries after a full iteration") + } + + close() + next, stop := iter.Pull2(d.ReadEntryNames()) + defer stop() + if name, err, hasNext := next(); assert.True(t, hasNext, "Should yield exactly one error") { + assert.Zero(t, name) + assert.ErrorContains(t, err, "use of closed file") + } + _, _, hasNext := next() + assert.False(t, hasNext, "Should yield exactly one error") +} diff --git a/inttest/reset/clutter-data-dir.sh b/inttest/reset/clutter-data-dir.sh new file mode 100644 index 000000000000..c7c8f1e14d4a --- /dev/null +++ b/inttest/reset/clutter-data-dir.sh @@ -0,0 +1,77 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2024 k0s authors +#shellcheck shell=ash + +set -eu + +make_dir() { mkdir -- "$1" && echo "$1"; } +make_file() { echo "$1" >"$1" && echo "$1"; } + +make_bind_mounts() { + local real="$1" + local target="$2" + + # Directory bind mount + make_dir "$real/real_dir" + make_file "$real/real_dir/real_dir_info.txt" + make_dir "$target/bind_dir" + mount --bind -- "$real/real_dir" "$target/bind_dir" + + # File bind mount + make_file "$real/real_file.txt" + make_file "$target/bind_file.txt" + mount --bind -- "$real/real_file.txt" "$target/bind_file.txt" + + # Recursive directory bind mount + make_dir "$real/real_recursive_dir" + make_file "$real/real_recursive_dir/real_recursive_dir.txt" + make_dir "$real/real_recursive_dir/bind_dir" + mount --bind -- "$real/real_dir" "$real/real_recursive_dir/bind_dir" + make_file "$real/real_recursive_dir/bind_file.txt" + mount --bind -- "$real/real_file.txt" "$real/real_recursive_dir/bind_file.txt" + make_dir "$target/rbind_dir" + mount --rbind -- "$real/real_recursive_dir" "$target/rbind_dir" + + # Directory overmounts + make_dir "$real/overmount_dir" + make_file "$real/overmount_dir/in_overmount_dir.txt" + mount --bind -- "$real/overmount_dir" "$target/bind_dir" + + # File overmounts + make_file "$real/overmount_file.txt" + mount --bind -- "$real/overmount_file.txt" "$target/bind_file.txt" +} + +clutter() { + local dataDir="$1" + local realDir + + realDir="$(mktemp -t -d k0s_reset_inttest.XXXXXX)" + + local dir="$dataDir"/cluttered + make_dir "$dir" + + # Directories and files with restricted permissions + make_dir "$dir/restricted_dir" + make_file "$dir/restricted_dir/no_read_file.txt" + chmod 000 -- "$dir/restricted_dir/no_read_file.txt" # No permissions on the file + make_dir "$dir/restricted_dir/no_exec_dir" + chmod 000 -- "$dir/restricted_dir/no_exec_dir" # No permissions on the directory + make_dir "$dir/restricted_dir/no_exec_nonempty_dir" + make_file "$dir/restricted_dir/no_exec_nonempty_dir/.hidden_file" + chmod 000 -- "$dir/restricted_dir/no_exec_nonempty_dir" # No permissions on the directory + + # Symlinks pointing outside the directory tree + make_dir "$realDir/some_dir" + make_file "$realDir/some_dir/real_file.txt" + ln -s -- "$realDir/some_dir/real_file.txt" "$dir/symlink_to_file" # Symlink to a file + ln -s -- "$realDir/some_dir" "$dir/symlink_to_dir" # Symlink to a directory + + # Bind mounts pointing outside the directory tree + make_bind_mounts "$realDir" "$dir" + + # Bind mounts outside the directory tree pointing into it + # make_bind_mounts "$dir" "$realDir" +} + +clutter "$@" diff --git a/inttest/reset/reset_test.go b/inttest/reset/reset_test.go index e34fcef8d989..d6e5d2a6712d 100644 --- a/inttest/reset/reset_test.go +++ b/inttest/reset/reset_test.go @@ -17,6 +17,11 @@ limitations under the License. package reset import ( + "bytes" + _ "embed" + "fmt" + "io" + "strings" "testing" testifysuite "github.com/stretchr/testify/suite" @@ -28,11 +33,14 @@ type suite struct { common.BootlooseSuite } +//go:embed clutter-data-dir.sh +var clutterScript []byte + func (s *suite) TestReset() { ctx := s.Context() workerNode := s.WorkerNode(0) - if ok := s.Run("k0s gets up", func() { + if !s.Run("k0s gets up", func() { s.Require().NoError(s.InitController(0, "--disable-components=konnectivity-server,metrics-server")) s.Require().NoError(s.RunWorkers()) @@ -44,11 +52,7 @@ func (s *suite) TestReset() { s.T().Log("waiting to see CNI pods ready") s.NoError(common.WaitForKubeRouterReady(ctx, kc), "CNI did not start") - }); !ok { - return - } - s.Run("k0s reset", func() { ssh, err := s.SSH(ctx, workerNode) s.Require().NoError(err) defer ssh.Disconnect() @@ -57,14 +61,47 @@ func (s *suite) TestReset() { s.NoError(ssh.Exec(ctx, "test -d /run/k0s", common.SSHStreams{}), "/run/k0s is not a directory") s.NoError(ssh.Exec(ctx, "pidof containerd-shim-runc-v2 >&2", common.SSHStreams{}), "Expected some running containerd shims") + }) { + return + } + var clutteringPaths bytes.Buffer + + if !s.Run("prepare k0s reset", func() { s.NoError(s.StopWorker(workerNode), "Failed to stop k0s") + ssh, err := s.SSH(ctx, workerNode) + s.Require().NoError(err) + defer ssh.Disconnect() + + streams, flushStreams := common.TestLogStreams(s.T(), "clutter data dir") + streams.In = bytes.NewReader(clutterScript) + streams.Out = io.MultiWriter(&clutteringPaths, streams.Out) + err = ssh.Exec(ctx, "sh -s -- /var/lib/k0s", streams) + flushStreams() + s.Require().NoError(err) + }) { + return + } + + s.Run("k0s reset", func() { + ssh, err := s.SSH(ctx, workerNode) + s.Require().NoError(err) + defer ssh.Disconnect() + streams, flushStreams := common.TestLogStreams(s.T(), "reset") err = ssh.Exec(ctx, "k0s reset --debug", streams) flushStreams() s.NoError(err, "k0s reset didn't exit cleanly") + for _, path := range strings.Split(string(bytes.TrimSpace(clutteringPaths.Bytes())), "\n") { + if strings.HasPrefix(path, "/var/lib/k0s") { + s.NoError(ssh.Exec(ctx, fmt.Sprintf("! test -e %q", path), common.SSHStreams{}), "Failed to verify non-existence of %s", path) + } else { + s.NoError(ssh.Exec(ctx, fmt.Sprintf("test -e %q", path), common.SSHStreams{}), "Failed to verify existence of %s", path) + } + } + // /var/lib/k0s is a mount point in the Docker container and can't be deleted, so it must be empty s.NoError(ssh.Exec(ctx, `x="$(ls -A /var/lib/k0s)" && echo "$x" >&2 && [ -z "$x" ]`, common.SSHStreams{}), "/var/lib/k0s is not empty") s.NoError(ssh.Exec(ctx, "! test -e /run/k0s", common.SSHStreams{}), "/run/k0s still exists") diff --git a/pkg/cleanup/directories_linux.go b/pkg/cleanup/directories_linux.go new file mode 100644 index 000000000000..7a999f1159c1 --- /dev/null +++ b/pkg/cleanup/directories_linux.go @@ -0,0 +1,502 @@ +/* +Copyright 2024 k0s authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package cleanup + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "os" + "path/filepath" + "syscall" + + osunix "github.com/k0sproject/k0s/internal/os/unix" + + "github.com/sirupsen/logrus" + "golang.org/x/sys/unix" +) + +type directories struct { + Config *Config +} + +// Name returns the name of the step +func (d *directories) Name() string { + return "remove directories step" +} + +func (d *directories) Run() error { + log := logrus.StandardLogger() + + var errs []error + if err := cleanupBeneath(log, d.Config.dataDir); err != nil { + errs = append(errs, fmt.Errorf("failed to delete data directory: %w", err)) + } + if err := cleanupBeneath(log, d.Config.runDir); err != nil { + errs = append(errs, fmt.Errorf("failed to delete run directory: %w", err)) + } + return errors.Join(errs...) +} + +const ( + cleanupOFlags = unix.O_NOFOLLOW + cleanupAtFlags = unix.AT_NO_AUTOMOUNT | unix.AT_SYMLINK_NOFOLLOW + cleanupResolveFlags = unix.RESOLVE_BENEATH | unix.RESOLVE_NO_MAGICLINKS +) + +// Recursively removes the specified directory. Attempts to do this by making +// sure that everything not in that directory is left untouched, i.e. the +// recursion will not follow any file system links such as symlinks and mount +// points. Instead, any mount points will be unmounted recursively. +// +// Note that this code assumes to be run with elevated privileges. +func cleanupBeneath(log logrus.FieldLogger, dirPath string) (err error) { + // The real path is required as the code may be checking the mount info via + // the proc filesystem. + realDirPath, err := filepath.EvalSymlinks(dirPath) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return nil + } + return err + } + + dir, err := osunix.OpenDir(realDirPath, cleanupOFlags) + if err != nil { + return err + } + defer func() { err = errors.Join(err, dir.Close()) }() + + empty, err := cleanupPathNames(log, dir, realDirPath, true) + if err != nil { + return err + } + if empty { + if err := os.Remove(realDirPath); err != nil && !errors.Is(err, os.ErrNotExist) { + log.WithError(err).Warn("Leaving behind ", realDirPath) + } + } + + return nil +} + +func cleanupPathNames(log logrus.FieldLogger, dir *osunix.DirFD, dirPath string, unlink bool) (bool, error) { + var leftovers bool + for name, err := range dir.ReadEntryNames() { + if err != nil { + return false, fmt.Errorf("failed to enumerate directory entries: %w", err) + } + if !cleanupPathNameLoop(log, dir, dirPath, name, unlink) { + leftovers = true + } + } + + return !leftovers, nil +} + +type cleanupOutcome uint8 + +const ( + cleanIgnored cleanupOutcome = iota + 1 + cleanRetry + cleanUnlinked +) + +func cleanupPathNameLoop(log logrus.FieldLogger, dir *osunix.DirFD, dirPath, name string, unlink bool) bool { + for attempt := 1; ; attempt++ { + outcome, err := cleanupPathName(log, dir, dirPath, name, unlink) + if err == nil { + switch outcome { + case cleanUnlinked: + return true + case cleanIgnored: + return false + case cleanRetry: + if attempt < 256 { + log.Debugf("Retrying %s/%s after attempt %d (unlink=%t)", dirPath, name, attempt, unlink) + continue + } + err = errors.New("too many attempts") + default: + log.WithError(err).Errorf("Unexpected outcome while cleaning up %s/%s: %d", dirPath, name, outcome) + return false + } + } + + if errors.Is(err, os.ErrNotExist) { + return true + } + if errors.Is(err, syscall.EINTR) { + continue + } + + log.WithError(err).Warnf("Leaving behind %s/%s", dirPath, name) + return false + } +} + +func cleanupPathName(log logrus.FieldLogger, dir *osunix.DirFD, dirPath, name string, unlink bool) (_ cleanupOutcome, err error) { + if unlink { + log.Debugf("Trying to unlink %s/%s", dirPath, name) + if outcome, err := unlinkPathName(log, dir, dirPath, name); err != nil { + return 0, err + } else if outcome == unlinkUnlinked { + return cleanUnlinked, nil + } else if outcome == unlinkUnmounted { + // Path has been unmounted. Retry to catch overmounts. + return cleanRetry, nil + } + } + + // Try to recurse into the directory. + log.Debugf("Trying to to open %s/%s", dirPath, name) + subDir, isMountPoint, err := openDirName(dir, dirPath, name) + if err != nil { + // When not unlinking and this is not a directory, + // it might be a mounted file. Try to unmount it. + if !unlink && errors.Is(err, unix.ENOTDIR) { + status, err := getPathNameMountStatus(dir, dirPath, name) + if err != nil { + return 0, err + } + + if status == pathMountStatusRegular { + // Definitely not a mount point. Ignore the file. + return cleanIgnored, nil + } + + err = unmount(log, filepath.Join(dirPath, name)) + if err == nil { + // Path has been unmounted. Retry to catch overmounts. + return cleanRetry, nil + } + if status == pathMountStatusUnknown && errors.Is(err, unix.EINVAL) { + // Not a mount point (or the mount point is locked). + return cleanIgnored, nil + } + return 0, err + } + + return 0, err + } + + close := true + defer func() { + if close { + err = errors.Join(err, subDir.Close()) + } + }() + + // Disable recursive unlink if it's a mount point. + if isMountPoint { + unlink = false + } + + var empty bool + subDirPath := filepath.Join(dirPath, name) + empty, err = cleanupPathNames(log, subDir, subDirPath, unlink) + if err != nil { + return 0, err + } + + // The subDir can be closed now. In fact, it must be closed, so that a + // potential unmount will work. + close = false + if err := subDir.Close(); err != nil { + return 0, err + } + + if isMountPoint { + if err := unmount(log, subDirPath); err != nil { + return 0, err + } + return cleanRetry, nil + } + + if unlink && empty { + if err := dir.RemoveDir(name); err != nil { + return 0, err + } + + return cleanUnlinked, nil + } + + return cleanIgnored, nil +} + +type unlinkOutcome uint8 + +const ( + unlinkUnlinked unlinkOutcome = iota + 1 + unlinkRecurse + unlinkUnmounted +) + +func unlinkPathName(log logrus.FieldLogger, dir *osunix.DirFD, dirPath, name string) (unlinkOutcome, error) { + // First try to simply unlink the name. + // The assumption here is that mount points cannot be simply unlinked. + fileErr := dir.Remove(name) + if fileErr == nil || errors.Is(fileErr, os.ErrNotExist) { + // That worked. Mission accomplished. + return unlinkUnlinked, nil + } + + // Try to remove an empty directory. + dirErr := dir.RemoveDir(name) + switch { + case dirErr == nil: + // That worked. Mission accomplished. + return unlinkUnlinked, nil + + case errors.Is(dirErr, os.ErrExist): + // It's a non-empty directory. + return unlinkRecurse, nil + + case errors.Is(dirErr, unix.ENOTDIR): + // It's not a directory. If it's a mount point, try to unmount it. + if status, err := getPathNameMountStatus(dir, dirPath, name); err != nil { + return 0, errors.Join(fileErr, err) + } else if status != pathMountStatusRegular { + if err := unmount(log, filepath.Join(dirPath, name)); err != nil { + return 0, errors.Join(fileErr, err) + } + return unlinkUnmounted, nil + } + return 0, fileErr + + default: + // Try to clean up recursively for all other errors. + return unlinkRecurse, nil + } +} + +func openDirName(dir *osunix.DirFD, dirPath, name string) (_ *osunix.DirFD, isMountPoint bool, _ error) { + // Try to use openat2 to open it in a way that won't cross mounts. + subDir, err := dir.OpenDir2(name, unix.OpenHow{ + Flags: cleanupOFlags, + Resolve: cleanupResolveFlags | unix.RESOLVE_NO_XDEV, + }) + + // Did we try to cross a mount point? + if errors.Is(err, unix.EXDEV) { + isMountPoint = true + subDir, err = dir.OpenDir2(name, unix.OpenHow{ + Flags: cleanupOFlags, + Resolve: cleanupResolveFlags, + }) + } + + if err == nil || !errors.Is(err, errors.ErrUnsupported) { + return subDir, isMountPoint, err + } + + // Fallback to legacy open. + subDir, err = dir.OpenDir(name, cleanupOFlags) + if err != nil { + return nil, false, err + } + + close := true + defer func() { + if close { + err = errors.Join(err, subDir.Close()) + } + }() + + subDirPath := filepath.Join(dirPath, name) + status, err := getPathMountStatus(dir, subDir, subDirPath) + if err != nil { + return nil, false, err + } + if status == pathMountStatusMountPoint { + isMountPoint = true + } else if status == pathMountStatusUnknown { + // There's still no bullet-proof evidence to rule out that path is + // actually a mount point. As a last resort, have a look at the proc fs. + isMountPoint, err = mountInfoListsMountPoint("/proc/self/mountinfo", subDirPath) + if err != nil { + // The proc filesystem check failed, too. No other checks are left. + // Assume that it's not a mount point. + isMountPoint = false + } + } + + close = false + return subDir, isMountPoint, nil +} + +type pathMountStatus uint8 + +const ( + pathMountStatusUnknown pathMountStatus = iota + pathMountStatusRegular + pathMountStatusMountPoint +) + +func getPathNameMountStatus(dir *osunix.DirFD, dirPath, name string) (pathMountStatus, error) { + if path, err := dir.Open2(name, unix.OpenHow{ + Flags: cleanupOFlags | unix.O_PATH, + Resolve: cleanupResolveFlags | unix.RESOLVE_NO_XDEV, + }); err == nil { + return pathMountStatusRegular, path.Close() + } else if errors.Is(err, unix.EXDEV) { + return pathMountStatusMountPoint, nil + } else if !errors.Is(err, errors.ErrUnsupported) { + return 0, err + } + + path, err := dir.Open(name, cleanupOFlags|unix.O_PATH, 0) + if err != nil { + return 0, err + } + + defer func() { err = errors.Join(err, path.Close()) }() + return getPathMountStatus(dir, path, filepath.Join(dirPath, name)) +} + +func getPathMountStatus(dir *osunix.DirFD, fd osunix.LinuxPath, path string) (pathMountStatus, error) { + // Don't bother to try statx() here. The interesting fields (stx_mnt_id) and + // attributes (STATX_ATTR_MOUNT_ROOT) have been introduced in Linux 5.8, + // whereas openat2() is a thing since Linux 5.6. So its highly unlikely that + // those will be available when openat2() isn't. + + // Check if the paths have different device numbers. + if dirStat, err := dir.StatSelf(); err != nil { + return 0, err + } else if pathStat, err := fd.StatSelf(); err != nil { + return 0, err + } else if dirStat.Dev != pathStat.Dev { + return pathMountStatusMountPoint, nil + } + + // Try to expire the mount point. + err := unix.Unmount(path, unix.MNT_EXPIRE|unix.UMOUNT_NOFOLLOW) + switch { + case errors.Is(err, unix.EINVAL): + // This is the expected error when path is not a mount point. Note that + // there's still the chance that path is referring to a locked mount + // point, i.e. a mount point that is part of a more privileged mount + // namespace than k0s is in. That's not easy to rule out ... + // See https://www.man7.org/linux/man-pages/man2/umount.2.html#ERRORS. + // See https://man7.org/linux/man-pages/man7/mount_namespaces.7.html. + return pathMountStatusUnknown, nil + + case errors.Is(err, unix.EBUSY): + // This is the expected error when path is a mount point. It indicates + // that the resource is in use, which is guaranteed because there's an + // open file descriptor for it. + return pathMountStatusMountPoint, nil + + case errors.Is(err, unix.EAGAIN): + // This is the expected error when path is an unused mount point. This + // shouldn't happen, since there's still an open file descriptor to path. + return 0, &os.PathError{ + Op: "unmount", + Path: path, + Err: fmt.Errorf("supposedly unreachable code path: %w", err), + } + + case errors.Is(err, unix.EPERM): + // This is the expected error if k0s doesn't have the privileges to + // unmount path. Since this code should be run with root privileges, + // this is not expected to happen. Anyhow, don't bail out. + return pathMountStatusUnknown, nil + + case err == nil: + // This means that the path was unmounted, as it has already been + // expired before. This shouldn't happen, since there's still an open + // file descriptor to path. + return 0, &os.PathError{ + Op: "unmount", + Path: path, + Err: errors.New("supposedly unreachable code path: success"), + } + + default: + // Pass on all other errors. + return 0, &os.PathError{Op: "unmount", Path: path, Err: err} + } +} + +// Checks whether path is listed as a mount point in the proc filesystems +// mountinfo file. +// +// https://man7.org/linux/man-pages/man5/proc_pid_mountinfo.5.html +func mountInfoListsMountPoint(mountInfoPath, path string) (bool, error) { + mountInfoBytes, err := os.ReadFile(mountInfoPath) + if err != nil { + return false, err + } + + mountInfoScanner := bufio.NewScanner(bytes.NewReader(mountInfoBytes)) + for mountInfoScanner.Scan() { + // The fifth field is the mount point. + fields := bytes.SplitN(mountInfoScanner.Bytes(), []byte{' '}, 6) + // Some characters are octal-escaped, most notably the space character. + if len(fields) > 5 && equalsOctalsUnsecaped(fields[4], path) { + return true, nil + } + } + + return false, mountInfoScanner.Err() +} + +// Compares if data and str are equal, converting any octal escape sequences of +// the form \NNN in data to their respective ASCII character on the fly. +func equalsOctalsUnsecaped(data []byte, str string) bool { + dlen, slen := len(data), len(str) + + // An escape sequence takes 4 bytes. + // The unescaped length of data is in range [dlen/4, dlen]. + if slen < dlen/4 || slen > dlen { + return false // Lengths don't match, data and str cannot be equal. + } + + doff := 0 + for soff := 0; soff < slen; soff, doff = soff+1, doff+1 { + if doff >= dlen { + return false // str is longer than unescaped data + } + ch := data[doff] + if ch == '\\' && doff < dlen-3 { // The next three bytes should be octal digits. + d1, d2, d3 := data[doff+1]-'0', data[doff+2]-'0', data[doff+3]-'0' + // The ASCII character range is [0, 127] decimal, which corresponds + // to [0, 177] octal. Check if the digits are in range. + if d1 <= 1 && d2 <= 7 && d3 <= 7 { + ch = d1<<6 | d2<<3 | d3 // Convert from octal digits (3 bits per digit). + doff += 3 // Skip the three digits in the next iteration. + } + } + + if str[soff] != ch { + return false + } + } + + return doff == dlen // Both are equal if data has been fully read. +} + +func unmount(log logrus.FieldLogger, path string) error { + log.Debug("Attempting to unmount ", path) + if err := unix.Unmount(path, unix.UMOUNT_NOFOLLOW); err != nil { + return &os.PathError{Op: "unmount", Path: path, Err: err} + } + + log.Info("Unmounted ", path) + return nil +} diff --git a/pkg/cleanup/directories_linux_test.go b/pkg/cleanup/directories_linux_test.go new file mode 100644 index 000000000000..0c0c8b808d12 --- /dev/null +++ b/pkg/cleanup/directories_linux_test.go @@ -0,0 +1,123 @@ +/* +Copyright 2024 k0s authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package cleanup + +import ( + "os" + "path/filepath" + "syscall" + "testing" + + "github.com/k0sproject/k0s/internal/os/unix" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCleanupBeneath_NonExistent(t *testing.T) { + log := logrus.New() + log.SetLevel(logrus.DebugLevel) + dir := t.TempDir() + + err := cleanupBeneath(log, filepath.Join(dir, "non-existent")) + assert.NoError(t, err) + assert.DirExists(t, dir) +} + +func TestCleanupBeneath_Symlinks(t *testing.T) { + log := logrus.New() + log.SetLevel(logrus.DebugLevel) + unrelatedDir := t.TempDir() + cleanDir := t.TempDir() + + require.NoError(t, os.WriteFile(filepath.Join(unrelatedDir, "regular_file"), nil, 0644)) + require.NoError(t, os.Mkdir(filepath.Join(unrelatedDir, "regular_dir"), 0755)) + require.NoError(t, os.WriteFile(filepath.Join(unrelatedDir, "regular_dir", "some_file"), nil, 0644)) + + require.NoError(t, os.WriteFile(filepath.Join(cleanDir, "regular_file"), nil, 0644)) + require.NoError(t, os.Mkdir(filepath.Join(cleanDir, "regular_dir"), 0755)) + require.NoError(t, os.WriteFile(filepath.Join(cleanDir, "regular_dir", "some_file"), nil, 0644)) + + require.NoError(t, os.Symlink(filepath.Join(unrelatedDir, "regular_file"), filepath.Join(cleanDir, "symlinked_file"))) + require.NoError(t, os.Symlink(filepath.Join(unrelatedDir, "regular_dir"), filepath.Join(cleanDir, "symlinked_dir"))) + + err := cleanupBeneath(log, filepath.Join(cleanDir)) + assert.NoError(t, err) + assert.NoDirExists(t, cleanDir) + assert.FileExists(t, filepath.Join(unrelatedDir, "regular_file")) + assert.DirExists(t, filepath.Join(unrelatedDir, "regular_dir")) +} + +func TestGetPathMountStatus(t *testing.T) { + parent, err := unix.OpenDir(t.TempDir(), 0) + require.NoError(t, err) + t.Cleanup(func() { assert.NoError(t, parent.Close()) }) + + t.Run("file", func(t *testing.T) { + file, err := parent.Open("file", syscall.O_CREAT, 0644) + require.NoError(t, err) + t.Cleanup(func() { assert.NoError(t, file.Close()) }) + + status, err := getPathMountStatus(parent, file, filepath.Join(parent.Name(), file.Name())) + if assert.NoError(t, err) { + assert.Equal(t, pathMountStatusUnknown, status) + } + }) + + t.Run("dir", func(t *testing.T) { + require.NoError(t, parent.Mkdir("dir", 0755)) + dir, err := parent.OpenDir("dir", 0) + require.NoError(t, err) + t.Cleanup(func() { assert.NoError(t, dir.Close()) }) + + status, err := getPathMountStatus(parent, dir, filepath.Join(parent.Name(), dir.Name())) + if assert.NoError(t, err) { + assert.Equal(t, pathMountStatusUnknown, status) + } + }) +} + +func TestMountInfoListsMountPoint(t *testing.T) { + for _, path := range []string{ + `/`, + `/dev`, + `/sys/fs/bpf`, + `/mnt/path with spaces`, + `/mnt/path\with\backslashes`, + } { + ok, err := mountInfoListsMountPoint("testdata/mountinfo", path) + if assert.NoError(t, err, "For %s", path) { + assert.True(t, ok, "For %s", path) + } + } + + for _, path := range []string{ + ``, + `/de`, + `/dev/`, + `/mnt/path with space`, + `/mnt/path with spaces/`, + `/mnt/path\040with\040spaces`, + `/mnt/path\with\backslash`, + `/mnt/path\with\backslashes/`, + } { + ok, err := mountInfoListsMountPoint("testdata/mountinfo", path) + if assert.NoError(t, err, "For %s", path) { + assert.False(t, ok, "For %s", path) + } + } +} diff --git a/pkg/cleanup/directories.go b/pkg/cleanup/directories_other.go similarity index 96% rename from pkg/cleanup/directories.go rename to pkg/cleanup/directories_other.go index 540a9f928a95..0611515d54c7 100644 --- a/pkg/cleanup/directories.go +++ b/pkg/cleanup/directories_other.go @@ -1,3 +1,5 @@ +//go:build !linux + /* Copyright 2021 k0s authors @@ -81,7 +83,7 @@ func (d *directories) Run() error { return nil } -// this is for checking if the error retrned by os.RemoveAll is due to +// this is for checking if the error returned by os.RemoveAll is due to // it being a mount point. if it is, we can ignore the error. this way // we can't rely on os.RemoveAll instead of recursively deleting the // contents of the directory diff --git a/pkg/cleanup/testdata/mountinfo b/pkg/cleanup/testdata/mountinfo new file mode 100644 index 000000000000..d432a0faf592 --- /dev/null +++ b/pkg/cleanup/testdata/mountinfo @@ -0,0 +1,35 @@ +22 29 0:5 / /dev rw,nosuid shared:14 - devtmpfs devtmpfs rw,size=101288k,nr_inodes=250823,mode=755 +23 22 0:21 / /dev/pts rw,nosuid,noexec,relatime shared:15 - devpts devpts rw,gid=3,mode=620,ptmxmode=666 +24 22 0:22 / /dev/shm rw,nosuid,nodev shared:16 - tmpfs tmpfs rw +25 29 0:23 / /proc rw,nosuid,nodev,noexec,relatime shared:8 - proc proc rw +26 29 0:24 / /run rw,nosuid,nodev shared:17 - tmpfs tmpfs rw,size=506436k,mode=755 +27 26 0:25 / /run/keys rw,nosuid,nodev,relatime shared:18 - ramfs none rw,mode=750 +28 29 0:26 / /sys rw,nosuid,nodev,noexec,relatime shared:9 - sysfs sysfs rw +29 1 253:0 / / rw,relatime shared:1 - ext4 /dev/disk/by-label/nixos rw +30 29 0:27 / /nix/.ro-store rw,relatime shared:2 - 9p nix-store rw,dirsync,loose,access=client,msize=16384,trans=virtio +31 29 0:28 / /nix/.rw-store rw,relatime shared:3 - tmpfs tmpfs rw,mode=755 +34 29 0:29 / /nix/store rw,relatime shared:4 - overlay overlay rw,lowerdir=/mnt-root/nix/.ro-store,upperdir=/mnt-root/nix/.rw-store/upper,workdir=/mnt-root/nix/.rw-store/work +35 29 0:32 / /tmp/shared rw,relatime shared:6 - 9p shared rw,sync,dirsync,access=client,msize=16384,trans=virtio +36 29 0:33 / /tmp/xchg rw,relatime shared:7 - 9p xchg rw,sync,dirsync,access=client,msize=16384,trans=virtio +37 34 0:29 / /nix/store ro,relatime shared:5 - overlay overlay rw,lowerdir=/mnt-root/nix/.ro-store,upperdir=/mnt-root/nix/.rw-store/upper,workdir=/mnt-root/nix/.rw-store/work +38 28 0:6 / /sys/kernel/security rw,nosuid,nodev,noexec,relatime shared:10 - securityfs securityfs rw +39 28 0:34 / /sys/fs/cgroup rw,nosuid,nodev,noexec,relatime shared:11 - cgroup2 cgroup2 rw,nsdelegate,memory_recursiveprot +40 28 0:35 / /sys/fs/pstore rw,nosuid,nodev,noexec,relatime shared:12 - pstore pstore rw +41 28 0:36 / /sys/fs/bpf rw,nosuid,nodev,noexec,relatime shared:13 - bpf bpf rw,mode=700 +42 28 0:7 / /sys/kernel/debug rw,nosuid,nodev,noexec,relatime shared:19 - debugfs debugfs rw +43 22 0:37 / /dev/hugepages rw,nosuid,nodev,relatime shared:20 - hugetlbfs hugetlbfs rw,pagesize=2M +44 22 0:19 / /dev/mqueue rw,nosuid,nodev,noexec,relatime shared:21 - mqueue mqueue rw +68 26 0:38 / /run/credentials/systemd-journald.service ro,nosuid,nodev,noexec,relatime,nosymfollow shared:22 - ramfs none rw,mode=700 +70 26 0:39 / /run/credentials/systemd-tmpfiles-setup-dev-early.service ro,nosuid,nodev,noexec,relatime,nosymfollow shared:23 - ramfs none rw,mode=700 +45 28 0:40 / /sys/kernel/config rw,nosuid,nodev,noexec,relatime shared:24 - configfs configfs rw +46 28 0:41 / /sys/fs/fuse/connections rw,nosuid,nodev,noexec,relatime shared:25 - fusectl fusectl rw +76 26 0:42 / /run/credentials/systemd-tmpfiles-setup-dev.service ro,nosuid,nodev,noexec,relatime,nosymfollow shared:26 - ramfs none rw,mode=700 +78 26 0:43 / /run/credentials/systemd-sysctl.service ro,nosuid,nodev,noexec,relatime,nosymfollow shared:27 - ramfs none rw,mode=700 +109 26 0:45 / /run/credentials/systemd-vconsole-setup.service ro,nosuid,nodev,noexec,relatime,nosymfollow shared:54 - ramfs none rw,mode=700 +49 26 0:46 / /run/wrappers rw,nodev,relatime shared:56 - tmpfs tmpfs rw,mode=755 +115 26 0:47 / /run/credentials/systemd-tmpfiles-setup.service ro,nosuid,nodev,noexec,relatime,nosymfollow shared:58 - ramfs none rw,mode=700 +321 26 0:56 / /run/credentials/getty@tty1.service ro,nosuid,nodev,noexec,relatime,nosymfollow shared:197 - ramfs none rw,mode=700 +55 26 0:57 / /run/credentials/serial-getty@ttyS0.service ro,nosuid,nodev,noexec,relatime,nosymfollow shared:229 - ramfs none rw,mode=700 +61 26 0:58 / /run/user/1000 rw,nosuid,nodev,relatime shared:235 - tmpfs tmpfs rw,size=202572k,nr_inodes=50643,mode=700,uid=1000,gid=999 +67 29 0:59 / /mnt/path\040with\040spaces rw,relatime shared:241 - tmpfs tmpfs rw,size=10240k +75 29 0:60 / /mnt/path\134with\134backslashes rw,relatime shared:247 - tmpfs tmpfs rw,size=10240k