diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cc405d5..937675f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,13 +1,13 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.3.3 + rev: v0.5.6 hooks: - id: ruff args: - --fix - id: ruff-format - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.5.0 + rev: v4.6.0 hooks: - id: check-merge-conflict - id: check-yaml @@ -18,6 +18,6 @@ repos: args: - --fix=lf - repo: https://github.com/crate-ci/typos - rev: v1.19.0 + rev: v1.23.6 hooks: - id: typos diff --git a/art/cloudfront.py b/art/cloudfront.py new file mode 100644 index 0000000..08ec143 --- /dev/null +++ b/art/cloudfront.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +import logging +import time +from typing import Any + +log = logging.getLogger(__name__) + + +# Separated for testing purposes +def get_cloudfront_client() -> Any: + import boto3 + + return boto3.client("cloudfront") + + +def execute_cloudfront_invalidations(invalidations: dict[str, set[str]]) -> None: + cf_client = get_cloudfront_client() + ts = int(time.time()) + for dist_id, paths in invalidations.items(): + log.info("Creating CloudFront invalidation for %s: %d paths", dist_id, len(paths)) + caller_reference = f"art-{dist_id}-{ts}" + inv = cf_client.create_invalidation( + DistributionId=dist_id, + InvalidationBatch={ + "Paths": { + "Quantity": len(paths), + "Items": sorted(paths), + }, + "CallerReference": caller_reference, + }, + ) + log.info( + "Created CloudFront invalidation with caller reference %s: %s", + caller_reference, + inv["Invalidation"]["Id"], + ) diff --git a/art/command.py b/art/command.py index 7c89c89..70836a2 100644 --- a/art/command.py +++ b/art/command.py @@ -5,10 +5,11 @@ import os import shutil import tempfile -from typing import Any, Dict, List, Optional +from typing import List, Optional from art.config import ArtConfig, FileMapEntry from art.consts import DEFAULT_CONFIG_FILENAME +from art.context import ArtContext from art.excs import Problem from art.git import git_clone from art.manifest import Manifest @@ -95,35 +96,38 @@ def run_command(argv: Optional[List[str]] = None) -> None: args = Args(**vars(ap.parse_args(argv))) logging.basicConfig(level=(args.log_level or logging.INFO)) - config_args: Dict[str, Any] = {"dests": list(args.dests), "name": ""} - is_git = False if args.git_source: - config_args.update( + work_dir = tempfile.mkdtemp(prefix="art-git-") + atexit.register(shutil.rmtree, work_dir) + config = ArtConfig( + dests=list(args.dests), + name="", repo_url=args.git_source, ref=args.git_ref, - work_dir=tempfile.mkdtemp(prefix="art-git-"), + work_dir=work_dir, ) - is_git = True + git_clone(config) elif args.local_source: work_dir = os.path.abspath(args.local_source) - config_args.update( + config = ArtConfig( + dests=list(args.dests), + name="", repo_url=work_dir, work_dir=work_dir, ) else: ap.error("Either a git source or a local source must be defined") - - config = ArtConfig(**config_args) - - if is_git: - git_clone(config) - atexit.register(shutil.rmtree, config.work_dir) + return + context = ArtContext( + dry_run=bool(args.dry_run), + ) for forked_config in fork_configs_from_work_dir(config, filename=args.config_file): try: - process_config_postfork(args, forked_config) + process_config_postfork(context, args, forked_config) except Problem as p: ap.error(f"config {forked_config.name}: {p}") + context.execute_post_run_tasks() def clean_dest(dest: str) -> str: @@ -132,7 +136,11 @@ def clean_dest(dest: str) -> str: return dest -def process_config_postfork(args: Args, config: ArtConfig) -> None: +def process_config_postfork( + context: ArtContext, + args: Args, + config: ArtConfig, +) -> None: if not config.dests: raise Problem("No destination(s) specified (on command line or in config in source)") config.dests = [clean_dest(dest) for dest in config.dests] @@ -152,12 +160,12 @@ def process_config_postfork(args: Args, config: ArtConfig) -> None: for dest in config.dests: for suffix in suffixes: write( - config, + context=context, + config=config, dest=dest, path_suffix=suffix, manifest=manifest, wrap_filename=wrap_temp, - dry_run=args.dry_run, ) if wrap_temp: os.unlink(wrap_temp) diff --git a/art/context.py b/art/context.py new file mode 100644 index 0000000..7854fdf --- /dev/null +++ b/art/context.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +import dataclasses + +from art.cloudfront import execute_cloudfront_invalidations + + +@dataclasses.dataclass(frozen=True) +class ArtContext: + dry_run: bool = False + _cloudfront_invalidations: dict[str, set[str]] = dataclasses.field(default_factory=dict) + + def add_cloudfront_invalidation(self, dist_id: str, path: str) -> None: + self._cloudfront_invalidations.setdefault(dist_id, set()).add(path) + + def execute_post_run_tasks(self) -> None: + if self._cloudfront_invalidations: + execute_cloudfront_invalidations(self._cloudfront_invalidations) + self._cloudfront_invalidations.clear() diff --git a/art/s3.py b/art/s3.py index c6624c9..75474af 100644 --- a/art/s3.py +++ b/art/s3.py @@ -1,21 +1,27 @@ import logging +from functools import cache from typing import IO, Any, Dict from urllib.parse import urlparse -_s3_client = None +from art.context import ArtContext + log = logging.getLogger(__name__) +@cache def get_s3_client() -> Any: - global _s3_client - if not _s3_client: - import boto3 + import boto3 - _s3_client = boto3.client("s3") - return _s3_client + return boto3.client("s3") -def s3_write(url: str, source_fp: IO[bytes], *, options: Dict[str, Any], dry_run: bool) -> None: +def s3_write( + url: str, + source_fp: IO[bytes], + *, + options: Dict[str, Any], + context: ArtContext, +) -> None: purl = urlparse(url) s3_client = get_s3_client() assert purl.scheme == "s3" @@ -27,8 +33,12 @@ def s3_write(url: str, source_fp: IO[bytes], *, options: Dict[str, Any], dry_run if acl: kwargs["ACL"] = acl - if dry_run: + if context.dry_run: log.info("Dry-run: would write to S3 (ACL %s): %s", acl, url) return s3_client.put_object(**kwargs) log.info("Wrote to S3 (ACL %s): %s", acl, url) + + cf_distribution_id = options.get("cf-distribution-id") + if cf_distribution_id: + context.add_cloudfront_invalidation(cf_distribution_id, purl.path) diff --git a/art/write.py b/art/write.py index 2149eb0..f503a99 100644 --- a/art/write.py +++ b/art/write.py @@ -7,6 +7,7 @@ from urllib.parse import parse_qsl from art.config import ArtConfig +from art.context import ArtContext from art.manifest import Manifest from art.s3 import s3_write @@ -17,13 +18,13 @@ def _write_file( dest: str, source_fp: IO[bytes], *, + context: ArtContext, options: Optional[Dict[str, Any]] = None, - dry_run: bool = False, ) -> None: if options is None: options = {} writer = _get_writer_for_dest(dest) - writer(dest, source_fp, options=options, dry_run=dry_run) + writer(dest, source_fp, options=options, context=context) def _get_writer_for_dest(dest: str) -> Callable: # type: ignore[type-arg] @@ -34,8 +35,14 @@ def _get_writer_for_dest(dest: str) -> Callable: # type: ignore[type-arg] raise ValueError(f"Invalid destination: {dest}") -def local_write(dest: str, source_fp: IO[bytes], *, options: Dict[str, Any], dry_run: bool) -> None: - if dry_run: +def local_write( + dest: str, + source_fp: IO[bytes], + *, + context: ArtContext, + options: Dict[str, Any], +) -> None: + if context.dry_run: log.info("Dry-run: Would have written local file %s", dest) return os.makedirs(os.path.dirname(dest), exist_ok=True) @@ -45,12 +52,12 @@ def local_write(dest: str, source_fp: IO[bytes], *, options: Dict[str, Any], dry def write( - config: ArtConfig, *, + context: ArtContext, + config: ArtConfig, dest: str, path_suffix: str, manifest: Manifest, - dry_run: bool, wrap_filename: Optional[str] = None, ) -> None: options = {} @@ -63,13 +70,18 @@ def write( dest_path = posixpath.join(dest, dest_filename) local_path = os.path.join(config.work_dir, fileinfo["path"]) with open(local_path, "rb") as infp: - _write_file(dest_path, infp, options=options, dry_run=dry_run) + _write_file( + dest_path, + infp, + context=context, + options=options, + ) _write_file( dest=posixpath.join(dest, ".manifest.json"), source_fp=io.BytesIO(manifest.as_json_bytes()), + context=context, options=options, - dry_run=dry_run, ) if config.wrap and wrap_filename: @@ -77,6 +89,6 @@ def write( _write_file( dest=posixpath.join(dest, config.wrap), source_fp=infp, + context=context, options=options, - dry_run=dry_run, ) diff --git a/art_tests/test_s3.py b/art_tests/test_s3.py index daedb90..d37e5f3 100644 --- a/art_tests/test_s3.py +++ b/art_tests/test_s3.py @@ -1,13 +1,52 @@ import io +from unittest.mock import Mock +import pytest +from boto3 import _get_default_session + +from art import cloudfront +from art.context import ArtContext from art.s3 import get_s3_client from art.write import _write_file -def test_s3_acl(mocker): +@pytest.fixture(autouse=True) +def aws_fake_credentials(monkeypatch): + # Makes sure we don't accidentally use real AWS credentials. + monkeypatch.setattr(_get_default_session()._session, "_credentials", Mock()) + + +def test_s3_acl(monkeypatch): cli = get_s3_client() cli.put_object = cli.put_object # avoid magic - mocker.patch.object(cli, "put_object") + put_object = Mock() + monkeypatch.setattr(cli, "put_object", put_object) body = io.BytesIO(b"test") - _write_file("s3://bukkit/key", body, options={"acl": "public-read"}) + _write_file("s3://bukkit/key", body, options={"acl": "public-read"}, context=ArtContext()) cli.put_object.assert_called_with(Bucket="bukkit", Key="key", ACL="public-read", Body=body) + + +def test_s3_invalidate_cloudfront(monkeypatch): + cli = get_s3_client() + cli.put_object = cli.put_object # avoid magic + put_object = Mock() + monkeypatch.setattr(cli, "put_object", put_object) + body = io.BytesIO(b"test") + options = {"acl": "public-read", "cf-distribution-id": "UWUWU"} + context = ArtContext() + _write_file("s3://bukkit/key/foo/bar", body, options=options, context=context) + _write_file("s3://bukkit/key/baz/quux", body, options=options, context=context) + _write_file("s3://bukkit/key/baz/barple", body, options=options, context=context) + cf_client = Mock() + cf_client.create_invalidation.return_value = {"Invalidation": {"Id": "AAAAA"}} + monkeypatch.setattr(cloudfront, "get_cloudfront_client", Mock(return_value=cf_client)) + context.execute_post_run_tasks() + # Assert the 3 files get a single invalidation + cf_client.create_invalidation.assert_called_once() + call_kwargs = cf_client.create_invalidation.call_args.kwargs + assert call_kwargs["DistributionId"] == "UWUWU" + assert set(call_kwargs["InvalidationBatch"]["Paths"]["Items"]) == { + "/key/baz/barple", + "/key/baz/quux", + "/key/foo/bar", + } diff --git a/art_tests/test_write.py b/art_tests/test_write.py index 2a104ad..3beaa86 100644 --- a/art_tests/test_write.py +++ b/art_tests/test_write.py @@ -1,18 +1,23 @@ +import unittest.mock + import art.write from art.config import ArtConfig +from art.context import ArtContext from art.manifest import Manifest -def test_dest_options(mocker, tmpdir): +def test_dest_options(monkeypatch, tmpdir): cfg = ArtConfig(work_dir=str(tmpdir), dests=[str(tmpdir)], name="", repo_url=str(tmpdir)) mf = Manifest(files={}) - wf = mocker.patch("art.write._write_file") + wf = unittest.mock.MagicMock() + monkeypatch.setattr(art.write, "_write_file", wf) + context = ArtContext(dry_run=False) art.write.write( - cfg, + config=cfg, + context=context, dest="derp://foo/bar/?acl=quux", - path_suffix="blag", manifest=mf, - dry_run=False, + path_suffix="blag", ) call_kwargs = wf.call_args[1] assert call_kwargs["options"] == {"acl": "quux"} diff --git a/pyproject.toml b/pyproject.toml index 8a81d75..ba4bb80 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,6 @@ mypy = [ test = [ "build", "pytest-cov", - "pytest-mock~=3.12", "pytest~=8.1", ]