"""CFNgin hook for syncing static website to S3 bucket."""
from __future__ import annotations

import hashlib
import json
import logging
import os
import time
from operator import itemgetter
from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast

import yaml

from ....core.providers.aws.s3 import Bucket
from ....module.staticsite.options.models import RunwayStaticSiteExtraFileDataModel
from ....utils import JsonEncoder
from ..base import HookArgsBaseModel

if TYPE_CHECKING:
    from boto3.session import Session

    from ....context import CfnginContext

LOGGER = logging.getLogger(__name__)


class HookArgs(HookArgsBaseModel):
    """Hook arguments."""

    bucket_name: str
    """S3 bucket name."""

    cf_disabled: bool = False
    """Disable the use of CloudFront."""

    distribution_domain: str = "undefined"
    """Domain of the CloudFront distribution."""

    distribution_id: str = "undefined"
    """CloudFront distribution ID."""

    distribution_path: str = "/*"
    """Path in the CloudFront distribution to invalidate."""

    extra_files: List[RunwayStaticSiteExtraFileDataModel] = []
    """Extra files to sync to the S3 bucket."""

    website_url: Optional[str] = None
    """S3 bucket website URL."""


def get_archives_to_prune(
    archives: List[Dict[str, Any]], hook_data: Dict[str, Any]
) -> List[str]:
    """Return list of keys to delete.

    Args:
        archives: The full list of file archives.
        hook_data: CFNgin hook data.

    """
    files_to_skip = [
        hook_data[i]
        for i in ["current_archive_filename", "old_archive_filename"]
        if hook_data.get(i)
    ]
    archives.sort(  # sort from oldest to newest
        key=itemgetter("LastModified"), reverse=False
    )
    # drop all but the 15 most recent archives
    return [i["Key"] for i in archives[:-15] if i["Key"] not in files_to_skip]


def sync(context: CfnginContext, *__args: Any, **kwargs: Any) -> bool:
    """Sync static website to S3 bucket.

    Arguments parsed by :class:`~runway.cfngin.hooks.staticsite.upload_staticsite.HookArgs`.

    Args:
        context: The context instance.

    """
    args = HookArgs.parse_obj(kwargs)
    session = context.get_session()
    build_context = context.hook_data["staticsite"]
    invalidate_cache = False

    synced_extra_files = sync_extra_files(
        context,
        args.bucket_name,
        args.extra_files,
        hash_tracking_parameter=build_context.get("hash_tracking_parameter"),
    )

    if synced_extra_files:
        invalidate_cache = True

    if build_context["deploy_is_current"]:
        LOGGER.info("skipped upload; latest version already deployed")
    else:
        bucket = Bucket(context, args.bucket_name)
        bucket.sync_from_local(
            build_context["app_directory"],
            delete=True,
            exclude=[f.name for f in args.extra_files if f.name],
        )
        invalidate_cache = True

    if args.cf_disabled:
        LOGGER.info("STATIC WEBSITE URL: %s", args.website_url)
    elif invalidate_cache:
        invalidate_distribution(
            session,
            identifier=args.distribution_id,
            domain=args.distribution_domain,
            path=args.distribution_path,
        )

    LOGGER.info("sync complete")

    if not build_context["deploy_is_current"]:
        update_ssm_hash(context, session)

    prune_archives(context, session)
    return True


def update_ssm_hash(context: CfnginContext, session: Session) -> bool:
    """Update the SSM hash with the new tracking data.

    Args:
        context: Context instance.
        session: boto3 session.

    """
    build_context = context.hook_data["staticsite"]

    if not build_context.get("hash_tracking_disabled"):
        hash_param = build_context["hash_tracking_parameter"]
        hash_value = build_context["hash"]
        LOGGER.info("updating SSM parameter %s with hash %s", hash_param, hash_value)
        set_ssm_value(
            session,
            hash_param,
            hash_value,
            "Hash of currently deployed static website source",
        )
    return True


def invalidate_distribution(
    session: Session,
    *,
    domain: str = "undefined",
    identifier: str,
    path: str = "/*",
    **_: Any,
) -> bool:
    """Invalidate the current distribution.

    Args:
        session: The current boto3 session.
        domain: The distribution domain.
        identifier: The distribution id.
        path: The distribution path.

    """
    LOGGER.info("invalidating CloudFront distribution: %s (%s)", identifier, domain)
    cf_client = session.client("cloudfront")
    cf_client.create_invalidation(
        DistributionId=identifier,
        InvalidationBatch={
            "Paths": {"Quantity": 1, "Items": [path]},
            "CallerReference": str(time.time()),
        },
    )
    LOGGER.info("CloudFront invalidation complete")
    return True
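
# Minimal usage sketch (the distribution ID is a hypothetical placeholder):
#
#   invalidate_distribution(session, identifier="EDFDVBD6EXAMPLE")
#
# CallerReference is derived from the current timestamp, so repeated calls
# produce distinct invalidation requests.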


def prune_archives(context: CfnginContext, session: Session) -> bool:
    """Prune old archives from the artifact bucket.

    Args:
        context: The context instance.
        session: The boto3 session.

    """
    LOGGER.info("cleaning up old site archives...")
    archives: List[Dict[str, Any]] = []
    s3_client = session.client("s3")
    list_objects_v2_paginator = s3_client.get_paginator("list_objects_v2")
    response_iterator = list_objects_v2_paginator.paginate(
        Bucket=context.hook_data["staticsite"]["artifact_bucket_name"],
        Prefix=context.hook_data["staticsite"]["artifact_key_prefix"],
    )

    for page in response_iterator:
        archives.extend(page.get("Contents", []))  # type: ignore

    archives_to_prune = get_archives_to_prune(archives, context.hook_data["staticsite"])

    # iterate in chunks of 1000 to match the delete_objects limit
    for objects in [
        archives_to_prune[i : i + 1000] for i in range(0, len(archives_to_prune), 1000)
    ]:
        s3_client.delete_objects(
            Bucket=context.hook_data["staticsite"]["artifact_bucket_name"],
            Delete={"Objects": [{"Key": i} for i in objects]},
        )
    return True
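
# Note: S3's DeleteObjects API accepts at most 1000 keys per request, hence the
# chunking in prune_archives above; e.g. 2500 stale archives would be removed
# in batches of 1000, 1000, and 500.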


def auto_detect_content_type(filename: Optional[str]) -> Optional[str]:
    """Auto detect the content type based on the filename.

    Args:
        filename: A filename to use to auto detect the content type.

    Returns:
        The content type of the file. None if the content type could not
        be detected.

    """
    if not filename:
        return None

    _, ext = os.path.splitext(filename)

    if ext == ".json":
        return "application/json"

    if ext in [".yml", ".yaml"]:
        return "text/yaml"

    return None
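
# Examples of the mapping above:
#
#   auto_detect_content_type("config.json")  # -> "application/json"
#   auto_detect_content_type("values.yaml")  # -> "text/yaml"
#   auto_detect_content_type("index.html")   # -> None (unrecognized extension)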


def get_content_type(extra_file: RunwayStaticSiteExtraFileDataModel) -> Optional[str]:
    """Return the content type of the file.

    Args:
        extra_file: The extra file configuration.

    Returns:
        The content type of the extra file. If ``content_type`` is provided,
        that is returned; otherwise it is auto detected based on the name.

    """
    return extra_file.content_type or auto_detect_content_type(extra_file.name)


def get_content(extra_file: RunwayStaticSiteExtraFileDataModel) -> Optional[str]:
    """Get serialized content based on content_type.

    Args:
        extra_file: The extra file configuration.

    Returns:
        Serialized content based on the content_type.

    """
    if extra_file.content:
        if isinstance(extra_file.content, (dict, list)):
            if extra_file.content_type == "application/json":
                return json.dumps(extra_file.content)
            if extra_file.content_type == "text/yaml":
                return yaml.safe_dump(extra_file.content)
            raise ValueError(
                '"content_type" must be json or yaml if "content" is not a string'
            )
        if not isinstance(extra_file.content, str):
            raise TypeError(f"unsupported content: {type(extra_file.content)}")
    return cast(Optional[str], extra_file.content)
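
# Illustrative sketch (hypothetical field values; the model is assumed to
# accept these keywords since the code above reads .name, .content_type, and
# .content):
#
#   extra = RunwayStaticSiteExtraFileDataModel(
#       name="config.json",
#       content_type="application/json",
#       content={"api_url": "https://api.example.com"},
#   )
#   get_content(extra)  # -> '{"api_url": "https://api.example.com"}'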


def get_ssm_value(session: Session, name: str) -> Optional[str]:
    """Get the SSM parameter value.

    Args:
        session: The boto3 session.
        name: The parameter name.

    Returns:
        The parameter value.

    """
    ssm_client = session.client("ssm")

    try:
        return ssm_client.get_parameter(Name=name)["Parameter"]["Value"]
    except ssm_client.exceptions.ParameterNotFound:
        return None


def set_ssm_value(
    session: Session, name: str, value: Any, description: str = ""
) -> None:
    """Set the SSM parameter.

    Args:
        session: The boto3 session.
        name: The name of the parameter.
        value: The value of the parameter.
        description: A description of the parameter.

    """
    ssm_client = session.client("ssm")
    ssm_client.put_parameter(
        Name=name, Description=description, Value=value, Type="String", Overwrite=True
    )
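
# Usage sketch for the SSM helpers above (the parameter name is hypothetical):
#
#   set_ssm_value(session, "/example/static-site/hash", "abc123")
#   get_ssm_value(session, "/example/static-site/hash")  # -> "abc123"
#   get_ssm_value(session, "/example/missing")  # -> None (ParameterNotFound)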