Skip to content

Commit

Permalink
archive - fixing determination of archive root when root is '/' (#3036)
Browse files Browse the repository at this point in the history
* Initial commit

* Fixing units and path joins

* Ensuring paths are consistently ordered

* Adding changelog fragment

* Using os.path.join to ensure trailing slashes are present

* optimizing use of root in add_targets

* Applying initial review suggestions
  • Loading branch information
Ajpantuso authored Jul 24, 2021
1 parent d057b2e commit 31189e9
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 50 deletions.
4 changes: 4 additions & 0 deletions changelogs/fragments/3036-archive-root-path-fix.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
bugfixes:
- archive - fixing archive root determination when longest common root is ``/``
(https://github.com/ansible-collections/community.general/pull/3036).
95 changes: 45 additions & 50 deletions plugins/modules/files/archive.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,6 @@
LZMA_IMP_ERR = format_exc()
HAS_LZMA = False

PATH_SEP = to_bytes(os.sep)
PY27 = version_info[0:2] >= (2, 7)

STATE_ABSENT = 'absent'
Expand All @@ -213,16 +212,12 @@
STATE_INCOMPLETE = 'incomplete'


def _to_bytes(s):
return to_bytes(s, errors='surrogate_or_strict')


def _to_native(s):
return to_native(s, errors='surrogate_or_strict')

def common_path(paths):
    """Return the longest common directory among *paths*, with a trailing separator.

    Works for both text and byte paths (all elements are assumed to be of the
    same type as the first). An empty *paths* yields an empty text string.
    """
    # NOTE: ``six.binary_type`` is exactly the builtin ``bytes`` on both
    # Python 2 and 3, so the third-party dependency is unnecessary here.
    empty = b'' if paths and isinstance(paths[0], bytes) else ''

    # Joining ``empty`` onto each dirname forces a trailing separator, so
    # commonprefix cannot match inside a path component (e.g. /foo vs /foobar),
    # and the final join guarantees the returned root also ends with a separator.
    parents = [os.path.join(os.path.dirname(p), empty) for p in paths]
    return os.path.join(os.path.dirname(os.path.commonprefix(parents)), empty)


def expand_paths(paths):
Expand All @@ -239,10 +234,6 @@ def expand_paths(paths):
return expanded_path, is_globby


def is_archive(path):
return re.search(br'\.(tar|tar\.(gz|bz2|xz)|tgz|tbz2|zip)$', os.path.basename(path), re.IGNORECASE)


def legacy_filter(path, exclusion_patterns):
    # Thin delegation to matches_exclusion_patterns(); presumably kept as a
    # separate name for backward compatibility with the old exclusion
    # behaviour (the "legacy" name suggests so — confirm against callers).
    return matches_exclusion_patterns(path, exclusion_patterns)

Expand All @@ -251,6 +242,26 @@ def matches_exclusion_patterns(path, exclusion_patterns):
return any(fnmatch(path, p) for p in exclusion_patterns)


def is_archive(path):
    """Return a truthy regex match when *path* (bytes) ends in a known archive extension."""
    filename = os.path.basename(path)
    return re.search(br'\.(tar|tar\.(gz|bz2|xz)|tgz|tbz2|zip)$', filename, re.IGNORECASE)


def strip_prefix(prefix, string):
    """Return *string* with a leading *prefix* removed, or unchanged when absent."""
    if not string.startswith(prefix):
        return string
    return string[len(prefix):]


def _to_bytes(s):
    """Convert *s* to a byte string using the strict surrogate error handler."""
    converted = to_bytes(s, errors='surrogate_or_strict')
    return converted


def _to_native(s):
    """Convert *s* to the native string type using the strict surrogate error handler."""
    converted = to_native(s, errors='surrogate_or_strict')
    return converted


def _to_native_ascii(s):
    """Convert *s* to a native string, decoding strictly as ASCII."""
    converted = to_native(s, errors='surrogate_or_strict', encoding='ascii')
    return converted


@six.add_metaclass(abc.ABCMeta)
class Archive(object):
def __init__(self, module):
Expand All @@ -266,7 +277,6 @@ def __init__(self, module):
self.destination_state = STATE_ABSENT
self.errors = []
self.file = None
self.root = b''
self.successes = []
self.targets = []
self.not_found = []
Expand All @@ -275,7 +285,7 @@ def __init__(self, module):
self.expanded_paths, has_globs = expand_paths(paths)
self.expanded_exclude_paths = expand_paths(module.params['exclude_path'])[0]

self.paths = list(set(self.expanded_paths) - set(self.expanded_exclude_paths))
self.paths = sorted(set(self.expanded_paths) - set(self.expanded_exclude_paths))

if not self.paths:
module.fail_json(
Expand All @@ -285,6 +295,8 @@ def __init__(self, module):
msg='Error, no source paths were found'
)

self.root = common_path(self.paths)

if not self.must_archive:
self.must_archive = any([has_globs, os.path.isdir(self.paths[0]), len(self.paths) > 1])

Expand All @@ -298,6 +310,9 @@ def __init__(self, module):
msg='Error, must specify "dest" when archiving multiple files or trees'
)

if self.remove:
self._check_removal_safety()

self.original_size = self.destination_size()

def add(self, path, archive_name):
Expand All @@ -310,9 +325,8 @@ def add(self, path, archive_name):

def add_single_target(self, path):
if self.format in ('zip', 'tar'):
archive_name = re.sub(br'^%s' % re.escape(self.root), b'', path)
self.open()
self.add(path, archive_name)
self.add(path, strip_prefix(self.root, path))
self.close()
self.destination_state = STATE_ARCHIVED
else:
Expand All @@ -333,25 +347,18 @@ def add_single_target(self, path):
def add_targets(self):
self.open()
try:
match_root = re.compile(br'^%s' % re.escape(self.root))
for target in self.targets:
if os.path.isdir(target):
for directory_path, directory_names, file_names in os.walk(target, topdown=True):
if not directory_path.endswith(PATH_SEP):
directory_path += PATH_SEP

for directory_name in directory_names:
full_path = directory_path + directory_name
archive_name = match_root.sub(b'', full_path)
self.add(full_path, archive_name)
full_path = os.path.join(directory_path, directory_name)
self.add(full_path, strip_prefix(self.root, full_path))

for file_name in file_names:
full_path = directory_path + file_name
archive_name = match_root.sub(b'', full_path)
self.add(full_path, archive_name)
full_path = os.path.join(directory_path, file_name)
self.add(full_path, strip_prefix(self.root, full_path))
else:
archive_name = match_root.sub(b'', target)
self.add(target, archive_name)
self.add(target, strip_prefix(self.root, target))
except Exception as e:
if self.format in ('zip', 'tar'):
archive_format = self.format
Expand Down Expand Up @@ -384,26 +391,6 @@ def destination_size(self):

def find_targets(self):
for path in self.paths:
# Use the longest common directory name among all the files as the archive root path
if self.root == b'':
self.root = os.path.dirname(path) + PATH_SEP
else:
for i in range(len(self.root)):
if path[i] != self.root[i]:
break

if i < len(self.root):
self.root = os.path.dirname(self.root[0:i + 1])

self.root += PATH_SEP
# Don't allow archives to be created anywhere within paths to be removed
if self.remove and os.path.isdir(path):
prefix = path if path.endswith(PATH_SEP) else path + PATH_SEP
if self.destination.startswith(prefix):
self.module.fail_json(
path=', '.join(self.paths),
msg='Error, created archive can not be contained in source paths when remove=true'
)
if not os.path.lexists(path):
self.not_found.append(path)
else:
Expand Down Expand Up @@ -470,6 +457,14 @@ def result(self):
'expanded_exclude_paths': [_to_native(p) for p in self.expanded_exclude_paths],
}

def _check_removal_safety(self):
    """Fail the module when the destination archive sits inside a source directory
    that is slated for removal (which would delete the archive itself)."""
    for source in self.paths:
        if not os.path.isdir(source):
            continue
        # os.path.join with b'' forces a trailing separator so that a
        # sibling such as /foo2 does not match a /foo source directory.
        if self.destination.startswith(os.path.join(source, b'')):
            self.module.fail_json(
                path=b', '.join(self.paths),
                msg='Error, created archive can not be contained in source paths when remove=true'
            )

def _open_compressed_file(self, path, mode):
f = None
if self.format == 'gz':
Expand Down
73 changes: 73 additions & 0 deletions tests/unit/plugins/modules/files/test_archive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# -*- coding: utf-8 -*-
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)

from __future__ import (absolute_import, division, print_function)
__metaclass__ = type

import pytest

from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.general.tests.unit.compat.mock import Mock, patch
from ansible_collections.community.general.tests.unit.plugins.modules.utils import ModuleTestCase, set_module_args
from ansible_collections.community.general.plugins.modules.files.archive import get_archive, common_path


class TestArchive(ModuleTestCase):
    """Unit tests for the archive module's removal-safety check."""

    def setUp(self):
        super(TestArchive, self).setUp()

        self.mock_os_path_isdir = patch('os.path.isdir')
        self.os_path_isdir = self.mock_os_path_isdir.start()

    def tearDown(self):
        # Fix: stop() returns None, so the previous code clobbered the mock
        # reference instead of restoring state, and the parent class
        # tearDown was never invoked, skipping ModuleTestCase cleanup.
        self.mock_os_path_isdir.stop()
        super(TestArchive, self).tearDown()

    def test_archive_removal_safety(self):
        """Creating the archive inside a source directory with remove=True must fail."""
        set_module_args(
            dict(
                path=['/foo', '/bar', '/baz'],
                dest='/foo/destination.tgz',
                remove=True
            )
        )

        module = AnsibleModule(
            argument_spec=dict(
                path=dict(type='list', elements='path', required=True),
                format=dict(type='str', default='gz', choices=['bz2', 'gz', 'tar', 'xz', 'zip']),
                dest=dict(type='path'),
                exclude_path=dict(type='list', elements='path', default=[]),
                exclusion_patterns=dict(type='list', elements='path'),
                force_archive=dict(type='bool', default=False),
                remove=dict(type='bool', default=False),
            ),
            add_file_common_args=True,
            supports_check_mode=True,
        )

        # isdir is consulted once for the must_archive decision (first sorted
        # path, '/bar') and then once per path in the removal-safety loop;
        # only '/foo' reports as a directory there — TODO confirm the call
        # order stays in sync with Archive.__init__.
        self.os_path_isdir.side_effect = [True, False, False, True]

        module.fail_json = Mock()

        archive = get_archive(module)

        module.fail_json.assert_called_once_with(
            path=b', '.join(archive.paths),
            msg='Error, created archive can not be contained in source paths when remove=true'
        )


# Parametrized cases for test_common_path: (input paths, expected common root).
# Covers the empty list, the '/' root edge case in both text and byte form
# (the bug this change fixes), and nested directory prefixes.
PATHS = (
    ([], ''),
    (['/'], '/'),
    ([b'/'], b'/'),
    (['/foo', '/bar', '/baz', '/foobar', '/barbaz', '/foo/bar'], '/'),
    ([b'/foo', b'/bar', b'/baz', b'/foobar', b'/barbaz', b'/foo/bar'], b'/'),
    (['/foo/bar/baz', '/foo/bar'], '/foo/'),
    (['/foo/bar/baz', '/foo/bar/'], '/foo/bar/'),
)


@pytest.mark.parametrize("paths,root", PATHS)
def test_common_path(paths, root):
    """common_path must return the expected archive root for every PATHS case."""
    result = common_path(paths)
    assert result == root

0 comments on commit 31189e9

Please sign in to comment.