diff --git a/tests/framework/cli/micropkg/test_micropkg_pull.py b/tests/framework/cli/micropkg/test_micropkg_pull.py index 15067c5d92..120e2f70f4 100644 --- a/tests/framework/cli/micropkg/test_micropkg_pull.py +++ b/tests/framework/cli/micropkg/test_micropkg_pull.py @@ -1,8 +1,8 @@ import filecmp import shutil +import tarfile import textwrap from pathlib import Path -from tarfile import TarInfo from unittest.mock import Mock import pytest @@ -749,6 +749,86 @@ def test_pull_unsupported_protocol_by_fsspec( assert "Trying to use 'pip download'..." in result.output assert error_message in result.output + def test_micropkg_pull_invalid_sdist( + self, fake_project_cli, fake_repo_path, fake_metadata, tmp_path + ): + """ + Test for pulling an invalid sdist file locally with more than one package. + """ + error_message = ( + "Invalid sdist was extracted: exactly one directory was expected" + ) + + call_pipeline_create(fake_project_cli, fake_metadata) + call_micropkg_package(fake_project_cli, fake_metadata) + + sdist_file = ( + fake_repo_path / "dist" / _get_sdist_name(name=PIPELINE_NAME, version="0.1") + ) + assert sdist_file.is_file() + + with tarfile.open(sdist_file, "r:gz") as tar: + tar.extractall(tmp_path) + + # Create extra project + extra_project = tmp_path / f"{PIPELINE_NAME}-0.1_extra" + extra_project.mkdir() + (extra_project / "README.md").touch() + + # Recreate sdist + sdist_file.unlink() + with tarfile.open(sdist_file, "w:gz") as tar: + # Adapted from https://stackoverflow.com/a/65820259/554319 + for fn in tmp_path.iterdir(): + tar.add(fn, arcname=fn.relative_to(tmp_path)) + + result = CliRunner().invoke( + fake_project_cli, + ["micropkg", "pull", str(sdist_file)], + obj=fake_metadata, + ) + assert result.exit_code == 1 + assert error_message in result.stdout + + def test_micropkg_pull_invalid_package_contents( + self, fake_project_cli, fake_repo_path, fake_metadata, tmp_path + ): + """ + Test for pulling an invalid sdist file locally with more than one package. + """ + error_message = "Invalid package contents: exactly one package was expected" + + call_pipeline_create(fake_project_cli, fake_metadata) + call_micropkg_package(fake_project_cli, fake_metadata) + + sdist_file = ( + fake_repo_path / "dist" / _get_sdist_name(name=PIPELINE_NAME, version="0.1") + ) + assert sdist_file.is_file() + + with tarfile.open(sdist_file, "r:gz") as tar: + tar.extractall(tmp_path) + + # Create extra package + extra_package = tmp_path / f"{PIPELINE_NAME}-0.1" / f"{PIPELINE_NAME}_extra" + extra_package.mkdir() + (extra_package / "__init__.py").touch() + + # Recreate sdist + sdist_file.unlink() + with tarfile.open(sdist_file, "w:gz") as tar: + # Adapted from https://stackoverflow.com/a/65820259/554319 + for fn in tmp_path.iterdir(): + tar.add(fn, arcname=fn.relative_to(tmp_path)) + + result = CliRunner().invoke( + fake_project_cli, + ["micropkg", "pull", str(sdist_file)], + obj=fake_metadata, + ) + assert result.exit_code == 1 + assert error_message in result.stdout + @pytest.mark.parametrize( "tar_members,path_name", [ @@ -764,7 +844,7 @@ def test_path_traversal( """Test for checking path traversal attempt in tar file""" tar = Mock() tar.getmembers.return_value = [ - TarInfo(name=tar_name) for tar_name in tar_members + tarfile.TarInfo(name=tar_name) for tar_name in tar_members ] path = Path(path_name) with pytest.raises(Exception, match="Failed to safely extract tar file."):