Source code for runway.core.providers.aws._assume_role

"""Assume an AWS IAM role."""
from __future__ import annotations

import logging
from datetime import datetime
from typing import TYPE_CHECKING, ContextManager, Optional, Type, cast

from typing_extensions import TypedDict

if TYPE_CHECKING:
    from types import TracebackType

    from mypy_boto3_sts.type_defs import AssumedRoleUserTypeDef, CredentialsTypeDef

    from ...._logging import RunwayLogger
    from ....context import RunwayContext


LOGGER = cast("RunwayLogger", logging.getLogger(__name__.replace("._", ".")))

_KwargsTypeDef = TypedDict(
    "_KwargsTypeDef", DurationSeconds=int, RoleArn=str, RoleSessionName=str
)


[docs]class AssumeRole(ContextManager["AssumeRole"]): """Context manager for assuming an AWS role.""" assumed_role_user: AssumedRoleUserTypeDef credentials: CredentialsTypeDef ctx: RunwayContext duration_seconds: int revert_on_exit: bool session_name: str = "runway"
[docs] def __init__( self, context: RunwayContext, role_arn: Optional[str] = None, duration_seconds: Optional[int] = None, revert_on_exit: bool = True, session_name: Optional[str] = None, ): """Instantiate class. Args: context: Runway context object. role_arn: ARN of role to be assumed. duration_seconds: Seconds that the assumed role's credentials will be valid for. (default: 3600) revert_on_exit: Whether credentials in the environment will be reverted upon exiting the context manager. session_name: Name to use for the assumed role session. (default: runway) """ self.assumed_role_user = {"AssumedRoleId": "", "Arn": ""} self.credentials = { "AccessKeyId": "", "Expiration": datetime.now(), "SecretAccessKey": "", "SessionToken": "", } self.role_arn = role_arn self.ctx = context self.duration_seconds = duration_seconds or 3600 self.revert_on_exit = revert_on_exit self.session_name = session_name or "runway"
@property def _kwargs(self) -> _KwargsTypeDef: """Construct keyword arguments to pass to boto3 call.""" return { "DurationSeconds": self.duration_seconds, "RoleArn": self.role_arn or "", "RoleSessionName": self.session_name, }
[docs] def assume(self) -> None: """Perform role assumption.""" if not self.role_arn: LOGGER.debug("no role to assume") return if self.revert_on_exit: self.save_existing_iam_env_vars() sts_client = self.ctx.get_session().client("sts") LOGGER.info("assuming role %s...", self.role_arn) response = sts_client.assume_role(**self._kwargs) LOGGER.debug("sts.assume_role response: %s", response) if "Credentials" in response: self.assumed_role_user.update( response.get("AssumedRoleUser", cast("AssumedRoleUserTypeDef", {})) ) self.credentials.update(response["Credentials"]) self.ctx.env.vars.update( { "AWS_ACCESS_KEY_ID": response["Credentials"]["AccessKeyId"], "AWS_SECRET_ACCESS_KEY": response["Credentials"]["SecretAccessKey"], "AWS_SESSION_TOKEN": response["Credentials"]["SessionToken"], } ) LOGGER.verbose("updated environment with assumed credentials") else: raise ValueError("assume_role did not return Credentials")
[docs] def restore_existing_iam_env_vars(self) -> None: """Restore backed up IAM environment variables.""" if not self.role_arn: LOGGER.debug("no role was assumed; not reverting credentials") return for k in self.ctx.current_aws_creds.keys(): old = "OLD_" + k if self.ctx.env.vars.get(old): self.ctx.env.vars[k] = self.ctx.env.vars.pop(old) LOGGER.debug("reverted environment variables: %s", k) else: self.ctx.env.vars.pop(k, None) LOGGER.debug("removed environment variables: %s ", k)
[docs] def save_existing_iam_env_vars(self) -> None: """Backup IAM environment variables for later restoration.""" for k, v in self.ctx.current_aws_creds.items(): new = "OLD_" + k LOGGER.debug('saving environment variable "%s" as "%s"', k, new) self.ctx.env.vars[new] = cast(str, v)
[docs] def __enter__(self) -> AssumeRole: """Enter the context manager.""" LOGGER.debug("entering aws.AssumeRole context manager...") self.assume() return self
[docs] def __exit__( self, exc_type: Optional[Type[BaseException]], exc_value: Optional[BaseException], traceback: Optional[TracebackType], ) -> None: """Exit the context manager.""" if self.revert_on_exit: self.restore_existing_iam_env_vars() LOGGER.debug("aws.AssumeRole context manager exited")