# Source code for runway.blueprints.staticsite.auth_at_edge

"""Blueprint for the Authorization@Edge implementation of a Static Site.

Described in detail in this blogpost:
https://aws.amazon.com/blogs/networking-and-content-delivery/authorizationedge-how-to-use-lambdaedge-and-json-web-tokens-to-enhance-web-application-security/

"""
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional

import awacs.logs
import awacs.s3
from awacs.aws import Allow, Principal, Statement
from troposphere import Join, Output, awslambda, cloudfront, iam, s3

from .staticsite import StaticSite

if TYPE_CHECKING:
    from ...cfngin.blueprints.type_defs import BlueprintVariableTypeDef
    from ...context import CfnginContext

# Module logger; attaches to the package-wide "runway" logger rather than
# a per-module ``__name__`` logger.
LOGGER = logging.getLogger("runway")


class AuthAtEdge(StaticSite):
    """Auth@Edge Blueprint.

    Extends :class:`StaticSite` with the Lambda@Edge functions (check auth,
    HTTP headers, parse auth, refresh auth, sign out) and the CloudFront
    cache behaviors that route the sign-in redirect, token refresh and
    sign-out paths to them.

    """

    AUTH_VARIABLES: Dict[str, BlueprintVariableTypeDef] = {
        "OAuthScopes": {"type": list, "default": [], "description": "OAuth2 Scopes"},
        "PriceClass": {
            "type": str,
            "default": "PriceClass_100",  # US/Europe
            "description": "CF price class for the distribution.",
        },
        "RedirectPathSignIn": {
            "type": str,
            "default": "/parseauth",
            "description": "Auth@Edge: The URL that should "
            "handle the redirect from Cognito "
            "after sign-in.",
        },
        "RedirectPathAuthRefresh": {
            "type": str,
            "default": "/refreshauth",
            "description": "The URL path that should "
            "handle the JWT refresh request.",
        },
        "NonSPAMode": {
            "type": bool,
            "default": False,
            "description": "Whether Auth@Edge should omit SPA specific settings",
        },
        "SignOutUrl": {
            "type": str,
            "default": "/signout",
            "description": "The URL path that you can visit to sign-out.",
        },
    }
    IAM_ARN_PREFIX = "arn:aws:iam::aws:policy/service-role/"
    VARIABLES: ClassVar[Dict[str, BlueprintVariableTypeDef]] = {}

    def __init__(
        self,
        name: str,
        context: CfnginContext,
        mappings: Optional[Dict[str, Dict[str, Any]]] = None,
        description: Optional[str] = None,
    ) -> None:
        """Initialize the Blueprint.

        Args:
            name: A name for the blueprint.
            context: Context the blueprint is being executed under.
            mappings: CloudFormation Mappings to be used in the template.
            description: Used to describe the resulting CloudFormation template.

        """
        super().__init__(
            name=name, context=context, description=description, mappings=mappings
        )
        # Merge StaticSite's variables with the Auth@Edge ones; because
        # AUTH_VARIABLES is applied last it wins on any duplicate key.
        # NOTE(review): this updates the class-level ``VARIABLES`` dict in
        # place (shared by all instances); the update is idempotent, so
        # repeated instantiation yields the same mapping.
        self.VARIABLES.update(StaticSite.VARIABLES)
        self.VARIABLES.update(self.AUTH_VARIABLES)

    def create_template(self) -> None:
        """Create the Blueprinted template for Auth@Edge."""
        self.template.set_version("2010-09-09")
        self.template.set_description(
            "Authorization@Edge Static Website - Bucket, Lambdas, and Distribution"
        )

        # Resources
        bucket = self.add_bucket()
        oai = self.add_origin_access_identity()
        bucket_policy = self.add_cloudfront_bucket_policy(bucket, oai)
        # TODO Make this available in Auth@Edge
        lambda_function_associations: List[cloudfront.LambdaFunctionAssociation] = []

        if self.directory_index_specified:
            index_rewrite = self._get_index_rewrite_role_function_and_version()
            lambda_function_associations = self.get_directory_index_lambda_association(
                lambda_function_associations, index_rewrite["version"]
            )

        # Auth@Edge Lambdas: (title, description, handle). Each one gets its
        # own "<Title>LambdaExecutionRole"; the role is created before the
        # function, and the functions are created in this order.
        auth_lambda_specs = (
            (
                "CheckAuth",
                "Check Authorization information for request",
                "check_auth",
            ),
            (
                "HttpHeaders",
                "Additional Headers added to every response",
                "http_headers",
            ),
            (
                "ParseAuth",
                "Parse the Authorization Headers/Cookies for the request",
                "parse_auth",
            ),
            (
                "RefreshAuth",
                "Refresh the Authorization information when expired",
                "refresh_auth",
            ),
            (
                "SignOut",
                "Sign the User out of the application",
                "sign_out",
            ),
        )
        auth_lambdas: Dict[str, Dict[str, Any]] = {}
        for title, description, handle in auth_lambda_specs:
            auth_lambdas[handle] = self.get_auth_at_edge_lambda_and_ver(
                title,
                description,
                handle,
                self.add_lambda_execution_role(f"{title}LambdaExecutionRole", title),
            )

        # CloudFront Distribution
        distribution_options = self.get_distribution_options(
            bucket,
            oai,
            lambda_function_associations,
            auth_lambdas["check_auth"]["version"],
            auth_lambdas["http_headers"]["version"],
            auth_lambdas["parse_auth"]["version"],
            auth_lambdas["refresh_auth"]["version"],
            auth_lambdas["sign_out"]["version"],
        )
        self.add_cloudfront_distribution(bucket_policy, distribution_options)

    def get_auth_at_edge_lambda_and_ver(
        self, title: str, description: str, handle: str, role: iam.Role
    ) -> Dict[str, Any]:
        """Create a lambda function and its version.

        Args:
            title: The name of the function in PascalCase.
            description: Description to be displayed in the lambda panel.
            handle: The underscore separated representation of the name of the
                lambda. This handle is used to determine the handler for the
                lambda as well as identify the correct Code hook_data information.
            role: The Lambda Execution Role.

        Returns:
            A dict with ``"function"`` (the :class:`awslambda.Function`) and
            ``"version"`` (the :class:`awslambda.Version`) resources.

        """
        function = self.get_auth_at_edge_lambda(title, description, handle, role)
        return {"function": function, "version": self.add_version(title, function)}

    def get_auth_at_edge_lambda(
        self, title: str, description: str, handler: str, role: iam.Role
    ) -> awslambda.Function:
        """Create an Auth@Edge lambda resource.

        Also adds a ``Lambda<title>Arn`` output exposing the function's ARN.

        Args:
            title: The name of the function in PascalCase.
            description: Description to be displayed in the lambda panel.
            handler: The underscore separated representation of the name of the
                lambda. This handle is used to determine the handler for the
                lambda as well as identify the correct Code hook_data information.
            role: The Lambda Execution Role.

        Returns:
            The function resource added to the template.

        """
        lamb = self.template.add_resource(
            awslambda.Function(
                title,
                # NOTE(review): presumably retained because Lambda@Edge replicas
                # cannot be deleted immediately with the stack — confirm.
                DeletionPolicy="Retain",
                # Code comes from the hook that packaged the Auth@Edge lambdas.
                Code=self.context.hook_data["aae_lambda_config"][handler],
                Description=description,
                Handler="__init__.handler",
                Role=role.get_att("Arn"),
                # NOTE(review): python3.7 is a deprecated Lambda runtime; kept
                # as-is here because changing it alters deployed behavior.
                Runtime="python3.7",
            )
        )

        self.template.add_output(
            Output(
                f"Lambda{title}Arn",
                Description=f"Arn For the {title} Lambda Function",
                Value=lamb.get_att("Arn"),
            )
        )

        return lamb

    def add_version(
        self, title: str, lambda_function: awslambda.Function
    ) -> awslambda.Version:
        """Create a version association with a Lambda@Edge function.

        In order to ensure different versions of the function are appropriately
        uploaded a hash based on the code of the lambda is appended to the name.
        As the code changes so will this hash value.

        Args:
            title: The name of the function in PascalCase.
            lambda_function: The Lambda function.

        Returns:
            The version resource added to the template.

        """
        s3_key = lambda_function.properties["Code"].to_dict()["S3Key"]
        # Take the part of the key before its first "." and keep the last
        # "-"-separated segment of it. This assumes keys shaped like
        # "<name>-<hash>.<ext>" — TODO confirm against the packaging hook.
        code_hash = s3_key.split(".")[0].split("-")[-1]

        return self.template.add_resource(
            awslambda.Version(
                title + "Ver" + code_hash, FunctionName=lambda_function.ref()
            )
        )

    def get_distribution_options(
        self,
        bucket: s3.Bucket,
        oai: cloudfront.CloudFrontOriginAccessIdentity,
        lambda_funcs: List[cloudfront.LambdaFunctionAssociation],
        check_auth_lambda_version: awslambda.Version,
        http_headers_lambda_version: awslambda.Version,
        parse_auth_lambda_version: awslambda.Version,
        refresh_auth_lambda_version: awslambda.Version,
        sign_out_lambda_version: awslambda.Version,
    ) -> Dict[str, Any]:
        """Retrieve the options for our CloudFront distribution.

        Args:
            bucket: The bucket resource.
            oai: The origin access identity resource.
            lambda_funcs: List of Lambda Function associations for the default
                cache behavior. Not modified; a new list is built.
            check_auth_lambda_version: Lambda Function Version to use.
            http_headers_lambda_version: Lambda Function Version to use.
            parse_auth_lambda_version: Lambda Function Version to use.
            refresh_auth_lambda_version: Lambda Function Version to use.
            sign_out_lambda_version: Lambda Function Version to use.

        Returns:
            The CloudFront Distribution Options.

        """
        # Build a new list rather than appending to ``lambda_funcs`` so the
        # caller's list is not mutated as a side effect.
        default_cache_behavior_lambdas = [
            *lambda_funcs,
            cloudfront.LambdaFunctionAssociation(
                EventType="viewer-request",
                LambdaFunctionARN=check_auth_lambda_version.ref(),
            ),
            cloudfront.LambdaFunctionAssociation(
                EventType="origin-response",
                LambdaFunctionARN=http_headers_lambda_version.ref(),
            ),
        ]

        # Each auth endpoint path (taken from a blueprint variable) is handled
        # by a single viewer-request lambda. The order matches the original
        # behaviors: sign-in redirect, auth refresh, sign-out.
        auth_cache_behaviors = [
            cloudfront.CacheBehavior(
                PathPattern=self.variables[path_variable],
                Compress=True,
                ForwardedValues=cloudfront.ForwardedValues(QueryString=True),
                LambdaFunctionAssociations=[
                    cloudfront.LambdaFunctionAssociation(
                        EventType="viewer-request",
                        LambdaFunctionARN=lambda_version.ref(),
                    )
                ],
                TargetOriginId="protected-bucket",
                ViewerProtocolPolicy="redirect-to-https",
            )
            for path_variable, lambda_version in (
                ("RedirectPathSignIn", parse_auth_lambda_version),
                ("RedirectPathAuthRefresh", refresh_auth_lambda_version),
                ("SignOutUrl", sign_out_lambda_version),
            )
        ]

        return {
            "Aliases": self.add_aliases(),
            "Origins": [
                cloudfront.Origin(
                    DomainName=Join(".", [bucket.ref(), "s3.amazonaws.com"]),
                    S3OriginConfig=cloudfront.S3OriginConfig(
                        OriginAccessIdentity=Join(
                            "", ["origin-access-identity/cloudfront/", oai.ref()]
                        )
                    ),
                    Id="protected-bucket",
                )
            ],
            "CacheBehaviors": auth_cache_behaviors,
            "DefaultCacheBehavior": cloudfront.DefaultCacheBehavior(
                AllowedMethods=["GET", "HEAD"],
                Compress=self.variables.get("Compress", True),
                DefaultTTL="86400",
                ForwardedValues=cloudfront.ForwardedValues(QueryString=True),
                LambdaFunctionAssociations=default_cache_behavior_lambdas,
                TargetOriginId="protected-bucket",
                ViewerProtocolPolicy="redirect-to-https",
            ),
            "DefaultRootObject": "index.html",
            "Logging": self.add_logging_bucket(),
            "PriceClass": self.variables["PriceClass"],
            "Enabled": True,
            "WebACLId": self.add_web_acl(),
            "CustomErrorResponses": self._get_error_responses(),
            "ViewerCertificate": self.add_acm_cert(),
        }

    def _get_error_responses(self) -> List[cloudfront.CustomErrorResponse]:
        """Return error response based on site stack variables.

        When custom_error_responses are defined return those, if running
        in NonSPAMode return nothing, or return the standard error responses
        for a SPA (404 served as 200 with ``/index.html``).

        """
        if self.variables["custom_error_responses"]:
            return [
                cloudfront.CustomErrorResponse(
                    ErrorCode=response["ErrorCode"],
                    ResponseCode=response["ResponseCode"],
                    ResponsePagePath=response["ResponsePagePath"],
                )
                for response in self.variables["custom_error_responses"]
            ]

        if self.variables["NonSPAMode"]:
            return []

        return [
            cloudfront.CustomErrorResponse(
                ErrorCode=404, ResponseCode=200, ResponsePagePath="/index.html"
            )
        ]

    # pyright: reportIncompatibleMethodOverride=none
    def _get_cloudfront_bucket_policy_statements(  # pylint: disable=arguments-differ
        self, bucket: s3.Bucket, oai: cloudfront.CloudFrontOriginAccessIdentity
    ) -> List[Statement]:
        """Return bucket policy statements granting the OAI read/list access.

        Args:
            bucket: The bucket resource.
            oai: The origin access identity resource.

        Returns:
            Statements allowing ``s3:GetObject`` on the bucket's objects and
            ``s3:ListBucket`` on the bucket itself to the OAI's canonical user.

        """
        return [
            Statement(
                Action=[awacs.s3.GetObject],
                Effect=Allow,
                Principal=Principal("CanonicalUser", oai.get_att("S3CanonicalUserId")),
                Resource=[Join("", [bucket.get_att("Arn"), "/*"])],
            ),
            Statement(
                Action=[awacs.s3.ListBucket],
                Effect=Allow,
                Principal=Principal("CanonicalUser", oai.get_att("S3CanonicalUserId")),
                Resource=[bucket.get_att("Arn")],
            ),
        ]