Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for CloudFront invalidations #12

Merged
merged 6 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.3
rev: v0.5.6
hooks:
- id: ruff
args:
- --fix
- id: ruff-format
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
rev: v4.6.0
hooks:
- id: check-merge-conflict
- id: check-yaml
Expand All @@ -18,6 +18,6 @@ repos:
args:
- --fix=lf
- repo: https://github.com/crate-ci/typos
rev: v1.19.0
rev: v1.23.6
hooks:
- id: typos
37 changes: 37 additions & 0 deletions art/cloudfront.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from __future__ import annotations

import logging
import time
from typing import Any

log = logging.getLogger(__name__)


# Separated for testing purposes
def get_cloudfront_client() -> Any:
import boto3

return boto3.client("cloudfront")


def execute_cloudfront_invalidations(invalidations: dict[str, set[str]]) -> None:
cf_client = get_cloudfront_client()
ts = int(time.time())
for dist_id, paths in invalidations.items():
log.info("Creating CloudFront invalidation for %s: %d paths", dist_id, len(paths))
caller_reference = f"art-{dist_id}-{ts}"
inv = cf_client.create_invalidation(
DistributionId=dist_id,
InvalidationBatch={
"Paths": {
"Quantity": len(paths),
"Items": sorted(paths),
},
"CallerReference": caller_reference,
},
)
log.info(
"Created CloudFront invalidation with caller reference %s: %s",
caller_reference,
inv["Invalidation"]["Id"],
)
42 changes: 25 additions & 17 deletions art/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
import os
import shutil
import tempfile
from typing import Any, Dict, List, Optional
from typing import List, Optional

from art.config import ArtConfig, FileMapEntry
from art.consts import DEFAULT_CONFIG_FILENAME
from art.context import ArtContext
from art.excs import Problem
from art.git import git_clone
from art.manifest import Manifest
Expand Down Expand Up @@ -95,35 +96,38 @@ def run_command(argv: Optional[List[str]] = None) -> None:
args = Args(**vars(ap.parse_args(argv)))
logging.basicConfig(level=(args.log_level or logging.INFO))

config_args: Dict[str, Any] = {"dests": list(args.dests), "name": ""}
is_git = False
if args.git_source:
config_args.update(
work_dir = tempfile.mkdtemp(prefix="art-git-")
atexit.register(shutil.rmtree, work_dir)
config = ArtConfig(
dests=list(args.dests),
name="",
repo_url=args.git_source,
ref=args.git_ref,
work_dir=tempfile.mkdtemp(prefix="art-git-"),
work_dir=work_dir,
)
is_git = True
git_clone(config)
elif args.local_source:
work_dir = os.path.abspath(args.local_source)
config_args.update(
config = ArtConfig(
dests=list(args.dests),
name="",
repo_url=work_dir,
work_dir=work_dir,
)
else:
ap.error("Either a git source or a local source must be defined")

config = ArtConfig(**config_args)

if is_git:
git_clone(config)
atexit.register(shutil.rmtree, config.work_dir)
return
context = ArtContext(
dry_run=bool(args.dry_run),
)

for forked_config in fork_configs_from_work_dir(config, filename=args.config_file):
try:
process_config_postfork(args, forked_config)
process_config_postfork(context, args, forked_config)
except Problem as p:
ap.error(f"config {forked_config.name}: {p}")
context.execute_post_run_tasks()


def clean_dest(dest: str) -> str:
Expand All @@ -132,7 +136,11 @@ def clean_dest(dest: str) -> str:
return dest


def process_config_postfork(args: Args, config: ArtConfig) -> None:
def process_config_postfork(
context: ArtContext,
args: Args,
config: ArtConfig,
) -> None:
if not config.dests:
raise Problem("No destination(s) specified (on command line or in config in source)")
config.dests = [clean_dest(dest) for dest in config.dests]
Expand All @@ -152,12 +160,12 @@ def process_config_postfork(args: Args, config: ArtConfig) -> None:
for dest in config.dests:
for suffix in suffixes:
write(
config,
context=context,
config=config,
dest=dest,
path_suffix=suffix,
manifest=manifest,
wrap_filename=wrap_temp,
dry_run=args.dry_run,
)
if wrap_temp:
os.unlink(wrap_temp)
Expand Down
19 changes: 19 additions & 0 deletions art/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from __future__ import annotations

import dataclasses

from art.cloudfront import execute_cloudfront_invalidations


@dataclasses.dataclass(frozen=True)
class ArtContext:
dry_run: bool = False
_cloudfront_invalidations: dict[str, set[str]] = dataclasses.field(default_factory=dict)

def add_cloudfront_invalidation(self, dist_id: str, path: str) -> None:
self._cloudfront_invalidations.setdefault(dist_id, set()).add(path)

def execute_post_run_tasks(self) -> None:
if self._cloudfront_invalidations:
execute_cloudfront_invalidations(self._cloudfront_invalidations)
self._cloudfront_invalidations.clear()
26 changes: 18 additions & 8 deletions art/s3.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,27 @@
import logging
from functools import cache
from typing import IO, Any, Dict
from urllib.parse import urlparse

_s3_client = None
from art.context import ArtContext

log = logging.getLogger(__name__)


@cache
def get_s3_client() -> Any:
global _s3_client
if not _s3_client:
import boto3
import boto3

_s3_client = boto3.client("s3")
return _s3_client
return boto3.client("s3")


def s3_write(url: str, source_fp: IO[bytes], *, options: Dict[str, Any], dry_run: bool) -> None:
def s3_write(
url: str,
source_fp: IO[bytes],
*,
options: Dict[str, Any],
context: ArtContext,
) -> None:
purl = urlparse(url)
s3_client = get_s3_client()
assert purl.scheme == "s3"
Expand All @@ -27,8 +33,12 @@ def s3_write(url: str, source_fp: IO[bytes], *, options: Dict[str, Any], dry_run
if acl:
kwargs["ACL"] = acl

if dry_run:
if context.dry_run:
log.info("Dry-run: would write to S3 (ACL %s): %s", acl, url)
return
s3_client.put_object(**kwargs)
log.info("Wrote to S3 (ACL %s): %s", acl, url)

cf_distribution_id = options.get("cf-distribution-id")
if cf_distribution_id:
context.add_cloudfront_invalidation(cf_distribution_id, purl.path)
30 changes: 21 additions & 9 deletions art/write.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from urllib.parse import parse_qsl

from art.config import ArtConfig
from art.context import ArtContext
from art.manifest import Manifest
from art.s3 import s3_write

Expand All @@ -17,13 +18,13 @@ def _write_file(
dest: str,
source_fp: IO[bytes],
*,
context: ArtContext,
options: Optional[Dict[str, Any]] = None,
dry_run: bool = False,
) -> None:
if options is None:
options = {}
writer = _get_writer_for_dest(dest)
writer(dest, source_fp, options=options, dry_run=dry_run)
writer(dest, source_fp, options=options, context=context)


def _get_writer_for_dest(dest: str) -> Callable: # type: ignore[type-arg]
Expand All @@ -34,8 +35,14 @@ def _get_writer_for_dest(dest: str) -> Callable: # type: ignore[type-arg]
raise ValueError(f"Invalid destination: {dest}")


def local_write(dest: str, source_fp: IO[bytes], *, options: Dict[str, Any], dry_run: bool) -> None:
if dry_run:
def local_write(
dest: str,
source_fp: IO[bytes],
*,
context: ArtContext,
options: Dict[str, Any],
) -> None:
if context.dry_run:
log.info("Dry-run: Would have written local file %s", dest)
return
os.makedirs(os.path.dirname(dest), exist_ok=True)
Expand All @@ -45,12 +52,12 @@ def local_write(dest: str, source_fp: IO[bytes], *, options: Dict[str, Any], dry


def write(
config: ArtConfig,
*,
context: ArtContext,
config: ArtConfig,
dest: str,
path_suffix: str,
manifest: Manifest,
dry_run: bool,
wrap_filename: Optional[str] = None,
) -> None:
options = {}
Expand All @@ -63,20 +70,25 @@ def write(
dest_path = posixpath.join(dest, dest_filename)
local_path = os.path.join(config.work_dir, fileinfo["path"])
with open(local_path, "rb") as infp:
_write_file(dest_path, infp, options=options, dry_run=dry_run)
_write_file(
dest_path,
infp,
context=context,
options=options,
)

_write_file(
dest=posixpath.join(dest, ".manifest.json"),
source_fp=io.BytesIO(manifest.as_json_bytes()),
context=context,
options=options,
dry_run=dry_run,
)

if config.wrap and wrap_filename:
with open(wrap_filename, "rb") as infp:
_write_file(
dest=posixpath.join(dest, config.wrap),
source_fp=infp,
context=context,
options=options,
dry_run=dry_run,
)
45 changes: 42 additions & 3 deletions art_tests/test_s3.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,52 @@
import io
from unittest.mock import Mock

import pytest
from boto3 import _get_default_session

from art import cloudfront
from art.context import ArtContext
from art.s3 import get_s3_client
from art.write import _write_file


def test_s3_acl(mocker):
@pytest.fixture(autouse=True)
def aws_fake_credentials(monkeypatch):
# Makes sure we don't accidentally use real AWS credentials.
monkeypatch.setattr(_get_default_session()._session, "_credentials", Mock())


def test_s3_acl(monkeypatch):
cli = get_s3_client()
cli.put_object = cli.put_object # avoid magic
mocker.patch.object(cli, "put_object")
put_object = Mock()
monkeypatch.setattr(cli, "put_object", put_object)
body = io.BytesIO(b"test")
_write_file("s3://bukkit/key", body, options={"acl": "public-read"})
_write_file("s3://bukkit/key", body, options={"acl": "public-read"}, context=ArtContext())
cli.put_object.assert_called_with(Bucket="bukkit", Key="key", ACL="public-read", Body=body)


def test_s3_invalidate_cloudfront(monkeypatch):
cli = get_s3_client()
cli.put_object = cli.put_object # avoid magic
put_object = Mock()
monkeypatch.setattr(cli, "put_object", put_object)
body = io.BytesIO(b"test")
options = {"acl": "public-read", "cf-distribution-id": "UWUWU"}
context = ArtContext()
_write_file("s3://bukkit/key/foo/bar", body, options=options, context=context)
_write_file("s3://bukkit/key/baz/quux", body, options=options, context=context)
_write_file("s3://bukkit/key/baz/barple", body, options=options, context=context)
cf_client = Mock()
cf_client.create_invalidation.return_value = {"Invalidation": {"Id": "AAAAA"}}
monkeypatch.setattr(cloudfront, "get_cloudfront_client", Mock(return_value=cf_client))
context.execute_post_run_tasks()
# Assert the 3 files get a single invalidation
cf_client.create_invalidation.assert_called_once()
call_kwargs = cf_client.create_invalidation.call_args.kwargs
assert call_kwargs["DistributionId"] == "UWUWU"
assert set(call_kwargs["InvalidationBatch"]["Paths"]["Items"]) == {
"/key/baz/barple",
"/key/baz/quux",
"/key/foo/bar",
}
15 changes: 10 additions & 5 deletions art_tests/test_write.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
import unittest.mock

import art.write
from art.config import ArtConfig
from art.context import ArtContext
from art.manifest import Manifest


def test_dest_options(mocker, tmpdir):
def test_dest_options(monkeypatch, tmpdir):
cfg = ArtConfig(work_dir=str(tmpdir), dests=[str(tmpdir)], name="", repo_url=str(tmpdir))
mf = Manifest(files={})
wf = mocker.patch("art.write._write_file")
wf = unittest.mock.MagicMock()
monkeypatch.setattr(art.write, "_write_file", wf)
context = ArtContext(dry_run=False)
art.write.write(
cfg,
config=cfg,
context=context,
dest="derp://foo/bar/?acl=quux",
path_suffix="blag",
manifest=mf,
dry_run=False,
path_suffix="blag",
)
call_kwargs = wf.call_args[1]
assert call_kwargs["options"] == {"acl": "quux"}
Expand Down
Loading
Loading