Source code for runway.cfngin.hooks.staticsite.auth_at_edge.callback_url_retriever

"""Callback URL Retriever.

Dependency pre-hook responsible for ensuring correct
callback URLs are retrieved or a temporary one is used in its place.

"""
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, Dict, Optional

from ...base import HookArgsBaseModel

if TYPE_CHECKING:
    from .....context import CfnginContext

LOGGER = logging.getLogger(__name__)


class HookArgs(HookArgsBaseModel):
    """Arguments accepted by the callback URL retriever hook."""

    stack_name: str
    """Name of the stack whose outputs are inspected."""

    user_pool_arn: Optional[str] = None
    """ARN of the User Pool that should be checked for a client."""
def _stack_output(stack: Dict[str, Any], key: str) -> Any:
    """Return the value of the stack output whose ``OutputKey`` equals ``key``.

    Args:
        stack: A single stack description (one element of ``Stacks`` from a
            ``describe_stacks`` response).
        key: The ``OutputKey`` to look up.

    Raises:
        KeyError: The stack has no ``Outputs`` entry or no output with the
            requested key.

    """
    for output in stack["Outputs"]:
        if output["OutputKey"] == key:
            return output["OutputValue"]
    raise KeyError(f"stack output not found: {key}")


def get(context: CfnginContext, *__args: Any, **kwargs: Any) -> Dict[str, Any]:
    """Retrieve the callback URLs for User Pool Client Creation.

    When the User Pool is created a Callback URL is required. During a post
    hook entitled ``client_updater`` these Callback URLs are updated to that
    of the Distribution. Before then we need to ensure that if a Client
    already exists that the URLs for that client are used to prevent any
    interruption of service during deploy.

    Arguments parsed by
    :class:`~runway.cfngin.hooks.staticsite.auth_at_edge.callback_url_retriever.HookArgs`.

    Args:
        context: The context instance.

    Returns:
        A dict with a single ``callback_urls`` key. Its value is the list of
        callback URLs of the existing User Pool Client, or the placeholder
        ``["https://example.org"]`` when they can't be determined (lookup
        failed or the client has no callback URLs).

    """
    # Argument validation is intentionally outside the try block so that
    # invalid hook arguments are still reported to the caller.
    args = HookArgs.parse_obj(kwargs)
    session = context.get_session()
    cloudformation_client = session.client("cloudformation")
    cognito_client = session.client("cognito-idp")

    # Placeholder used until the existing client's URLs are confirmed.
    context_dict: Dict[str, Any] = {"callback_urls": ["https://example.org"]}

    try:
        # Describe the current stack if one exists.
        stack = cloudformation_client.describe_stacks(StackName=args.stack_name)[
            "Stacks"
        ][0]

        if args.user_pool_arn:
            # The pool ID is the last "/"-separated segment of the ARN.
            user_pool_id = args.user_pool_arn.rsplit("/", 1)[-1]
        else:
            user_pool_id = _stack_output(stack, "AuthAtEdgeUserPoolId")
        client_id = _stack_output(stack, "AuthAtEdgeClient")

        # Look up the existing user pool client and its callback URLs.
        resp = cognito_client.describe_user_pool_client(
            UserPoolId=user_pool_id, ClientId=client_id
        )
        callbacks = resp["UserPoolClient"].get("CallbackURLs")
        if callbacks:
            context_dict["callback_urls"] = callbacks
    except Exception:  # pylint: disable=broad-except
        # Best effort: any failure here falls back to the placeholder URL.
        # Log it so the fallback can be diagnosed.
        LOGGER.debug(
            "unable to retrieve existing callback URLs for stack %s; "
            "using placeholder",
            args.stack_name,
            exc_info=True,
        )
    return context_dict