"""Utility functions."""
from __future__ import annotations
import datetime
import hashlib
import importlib
import json
import logging
import os
import platform
import re
import stat
import sys
from contextlib import contextmanager
from decimal import Decimal
from pathlib import Path
from subprocess import check_call
from types import TracebackType
from typing import (
ContextManager, # deprecated in 3.9 for contextlib.AbstractContextManager
)
from typing import (
MutableMapping, # deprecated in 3.9 for collections.abc.MutableMapping
)
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
Optional,
Set,
Type,
Union,
cast,
overload,
)
import yaml
from pydantic import BaseModel as _BaseModel
from typing_extensions import Literal
# make this importable for util as it was before
from ..compat import cached_property # noqa: F401
# make this importable without defining __all__ yet.
# more things need to be moved of this file before starting an explicit __all__.
from ._file_hash import FileHash # noqa: F401
from ._version import Version # noqa: F401
if TYPE_CHECKING:
from mypy_boto3_cloudformation.type_defs import OutputTypeDef
AWS_ENV_VARS = ("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN")
DOC_SITE = "https://docs.onica.com/projects/runway"
EMBEDDED_LIB_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "embedded")
LOGGER = logging.getLogger(__name__)
[docs]class BaseModel(_BaseModel):
"""Base class for Runway models."""
[docs] def get(self, name: str, default: Any = None) -> Any:
"""Safely get the value of an attribute.
Args:
name: Attribute name to return the value for.
default: Value to return if attribute is not found.
"""
return getattr(self, name, default)
[docs] def __contains__(self, name: object) -> bool:
"""Implement evaluation of 'in' conditional.
Args:
name: The name to check for existence in the model.
"""
return name in self.__dict__
[docs] def __getitem__(self, name: str) -> Any:
"""Implement evaluation of self[name].
Args:
name: Attribute name to return the value for.
Returns:
The value associated with the provided name/attribute name.
Raises:
AttributeError: If attribute does not exist on this object.
"""
return getattr(self, name)
[docs] def __setitem__(self, name: str, value: Any) -> None:
"""Implement item assignment (e.g. ``self[name] = value``).
Args:
name: Attribute name to set.
value: Value to assign to the attribute.
"""
super().__setattr__(name, value) # type: ignore
[docs]class JsonEncoder(json.JSONEncoder):
"""Encode Python objects to JSON data.
This class can be used with ``json.dumps()`` to handle most data types
that can occur in responses from AWS.
Usage:
>>> json.dumps(data, cls=JsonEncoder)
"""
[docs] def default(self, o: Any) -> Any:
"""Encode types not supported by the default JSONEncoder.
Args:
o: Object to encode.
Returns:
JSON serializable data type.
Raises:
TypeError: Object type could not be encoded.
"""
if isinstance(o, Decimal):
return float(o)
if isinstance(o, (datetime.datetime, datetime.date)):
return o.isoformat()
if isinstance(o, _BaseModel):
return o.dict()
return super().default(o)
[docs]class MutableMap(MutableMapping[str, Any]):
"""Base class for mutable map objects."""
[docs] def __init__(self, **kwargs: Any) -> None:
"""Initialize class.
Provided ``kwargs`` are added to the object as attributes.
Example:
.. codeblock: python
obj = MutableMap(**{'key': 'value'})
print(obj.__dict__)
# {'key': 'value'}
"""
for key, value in kwargs.items():
if isinstance(value, dict):
setattr(self, key, MutableMap(**value))
else:
setattr(self, key, value)
if kwargs:
self._found_queries = MutableMap()
@property
def data(self) -> Dict[str, Any]:
"""Sanitized output of __dict__.
Removes anything that starts with ``_``.
"""
result: Dict[str, Any] = {}
for key, val in self.__dict__.items():
if key.startswith("_"):
continue
result[key] = val.data if isinstance(val, MutableMap) else val
return result
[docs] def clear_found_cache(self) -> None:
"""Clear _found_cache."""
for _, val in self.__dict__.items():
if isinstance(val, MutableMap):
val.clear_found_cache()
if hasattr(self, "_found_queries"):
self._found_queries.clear()
[docs] def find(self, query: str, default: Any = None, ignore_cache: bool = False) -> Any:
"""Find a value in the map.
Previously found queries are cached to increase search speed. The
cached value should only be used if values are not expected to change.
Args:
query: A period delimited string that is split to search for
nested values
default: The value to return if the query was unsuccessful.
ignore_cache: Ignore cached value.
"""
if not hasattr(self, "_found_queries"):
# if not created from kwargs, this attr won't exist yet
# this is done to prevent endless recursion
self._found_queries = MutableMap()
if not ignore_cache:
cached_result = self._found_queries.get(query, None)
if cached_result:
return cached_result
split_query = query.split(".")
if len(split_query) == 1:
result = self.get(split_query[0], default)
if result != default:
self._found_queries[split_query[0]] = result
return result
nested_value = self.get(split_query[0])
if not nested_value:
if self.get(query):
return self.get(query)
return default
nested_value = nested_value.find(split_query[1])
try:
nested_value = self[split_query[0]].find(
query=".".join(split_query[1:]),
default=default,
ignore_cache=ignore_cache,
)
if nested_value != default:
self._found_queries[query] = nested_value
return nested_value
except (AttributeError, KeyError):
return default
[docs] def get(self, key: str, default: Any = None) -> Any:
"""Implement evaluation of self.get.
Args:
key: Attribute name to return the value for.
default: Value to return if attribute is not found.
"""
return getattr(self, key, default)
[docs] def __bool__(self) -> bool:
"""Implement evaluation of instances as a bool."""
return bool(self.data)
[docs] def __contains__(self, value: Any) -> bool:
"""Implement evaluation of 'in' conditional."""
return value in self.data
[docs] def __getitem__(self, key: str) -> Any:
"""Implement evaluation of self[key].
Args:
key: Attribute name to return the value for.
Returns:
The value associated with the provided key/attribute name.
Raises:
AttributeError: If attribute does not exist on this object.
Example:
.. codeblock: python
obj = MutableMap(**{'key': 'value'})
print(obj['key'])
# value
"""
return getattr(self, key)
[docs] def __setitem__(self, key: str, value: Any) -> None:
"""Implement assignment to self[key].
Args:
key: Attribute name to associate with a value.
value: Value of a key/attribute.
Example:
.. codeblock: python
obj = MutableMap()
obj['key'] = 'value'
print(obj['key'])
# value
"""
if isinstance(value, dict):
setattr(self, key, MutableMap(**value))
else:
setattr(self, key, value)
[docs] def __delitem__(self, key: str) -> None:
"""Implement deletion of self[key].
Args:
key: Attribute name to remove from the object.
Example:
.. codeblock: python
obj = MutableMap(**{'key': 'value'})
del obj['key']
print(obj.__dict__)
# {}
"""
delattr(self, key)
[docs] def __len__(self) -> int:
"""Implement the built-in function len().
Example:
.. codeblock: python
obj = MutableMap(**{'key': 'value'})
print(len(obj))
# 1
"""
return len(self.__dict__)
[docs] def __iter__(self) -> Iterator[Any]:
"""Return iterator object that can iterate over all attributes.
Example:
.. codeblock: python
obj = MutableMap(**{'key': 'value'})
for k, v in obj.items():
print(f'{key}: {value}')
# key: value
"""
return iter(self.__dict__)
[docs] def __str__(self) -> str:
"""Return string representation of the object."""
return json.dumps(self.data, default=json_serial)
[docs]class SafeHaven(ContextManager["SafeHaven"]):
"""Context manager that caches and resets important values on exit.
Caches and resets os.environ, sys.argv, sys.modules, and sys.path.
"""
# pylint: disable=redefined-outer-name
[docs] def __init__(
self,
argv: Optional[Iterable[str]] = None,
environ: Optional[Dict[str, str]] = None,
sys_modules_exclude: Optional[Iterable[str]] = None,
sys_path: Optional[Iterable[str]] = None,
) -> None:
"""Instantiate class.
Args:
argv: Override the value of sys.argv.
environ: Update os.environ.
sys_modules_exclude: A list of modules to exclude when reverting
sys.modules to its previous state.
These are checked as ``module.startswith(name)``.
sys_path: Override the value of sys.path.
"""
self.__os_environ = dict(os.environ)
self.__sys_argv = list(sys.argv)
# deepcopy can't pickle sys.modules and dict()/.copy() are not safe
self.__sys_modules = {}
for k, v in sys.modules.items():
self.__sys_modules[k] = v
self.__sys_path = list(sys.path)
# more informative origin for log statements
self.logger = logging.getLogger("runway." + self.__class__.__name__)
self.sys_modules_exclude: Set[str] = (
set(sys_modules_exclude) if sys_modules_exclude else set()
)
self.sys_modules_exclude.add("runway")
if isinstance(argv, list):
sys.argv = argv
if isinstance(environ, dict):
os.environ.update(environ)
if isinstance(sys_path, list):
sys.path = sys_path
[docs] def reset_all(self) -> None:
"""Reset all values cached by this context manager."""
self.logger.debug("resetting all managed values...")
self.reset_os_environ()
self.reset_sys_argv()
self.reset_sys_modules()
self.reset_sys_path()
[docs] def reset_os_environ(self) -> None:
"""Reset the value of os.environ."""
self.logger.debug("resetting os.environ: %s", json.dumps(self.__os_environ))
os.environ.clear()
os.environ.update(self.__os_environ)
[docs] def reset_sys_argv(self) -> None:
"""Reset the value of sys.argv."""
self.logger.debug("resetting sys.argv: %s", json.dumps(self.__sys_argv))
sys.argv = self.__sys_argv
[docs] def reset_sys_modules(self) -> None:
"""Reset the value of sys.modules."""
self.logger.debug("resetting sys.modules...")
# sys.modules can be manipulated to force reloading modules but,
# replacing it outright does not work as expected
for module in list(sys.modules.keys()):
if module not in self.__sys_modules and not any(
module.startswith(n) for n in self.sys_modules_exclude
):
self.logger.debug(
'removed sys.module: {"%s": "%s"}', module, sys.modules.pop(module)
)
[docs] def reset_sys_path(self) -> None:
"""Reset the value of sys.path."""
self.logger.debug("resetting sys.path: %s", json.dumps(self.__sys_path))
sys.path = self.__sys_path
[docs] def __enter__(self) -> SafeHaven:
"""Enter the context manager.
Returns:
SafeHaven: Instance of the context manager.
"""
self.logger.debug("entering a safe haven...")
return self
[docs] def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
"""Exit the context manager."""
self.logger.debug("leaving the safe haven...")
self.reset_all()
# TODO remove after https://github.com/yaml/pyyaml/issues/234 is resolved
[docs]class YamlDumper(yaml.Dumper):
"""Custom YAML Dumper.
This Dumper allows for YAML to be output to follow YAML spec 1.2,
example 2.3 of collections (2.1). This provides an output that is more
humanreadable and complies with yamllint.
Example:
>>> print(yaml.dump({'key': ['val1', 'val2']}, Dumper=YamlDumper))
Note:
YAML 1.2 Specification: https://yaml.org/spec/1.2/spec.html
used for reference.
"""
[docs] def increase_indent(self, flow: bool = False, indentless: bool = False) -> None:
"""Override parent method."""
return super().increase_indent(flow, False) # type: ignore
[docs]@contextmanager
def argv(*args: str) -> Iterator[None]:
"""Context manager for temporarily changing sys.argv."""
original_argv = sys.argv.copy()
try:
sys.argv = list(args) # convert tuple to list
yield
finally:
# always restore original value
sys.argv = original_argv
[docs]@contextmanager
def change_dir(newdir: Union[Path, str]) -> Iterator[None]:
"""Change directory.
Adapted from http://stackoverflow.com/a/24176022
Args:
newdir: Path to change directory into.
"""
prevdir = Path.cwd().resolve()
if isinstance(newdir, str):
newdir = Path(newdir)
os.chdir(newdir.resolve())
try:
yield
finally:
os.chdir(prevdir)
[docs]def ensure_file_is_executable(path: str) -> None:
"""Exit if file is not executable.
Args:
path: Path to file.
Raises:
SystemExit: file is not executable.
"""
if platform.system() != "Windows" and (
not stat.S_IXUSR & os.stat(path)[stat.ST_MODE]
):
print(f"Error: File {path} is not executable") # noqa: T201
sys.exit(1)
[docs]def ensure_string(value: Any) -> str:
"""Ensure value is a string."""
if isinstance(value, str):
return value
if isinstance(value, bytes):
return value.decode()
raise TypeError(f"Expected str or bytes but received {type(value)}")
[docs]@contextmanager
def environ(env: Optional[Dict[str, str]] = None, **kwargs: str) -> Iterator[None]:
"""Context manager for temporarily changing os.environ.
The original value of os.environ is restored upon exit.
Args:
env: Dictionary to use when updating os.environ.
"""
env = env or {}
env.update(kwargs)
original_env = dict(os.environ)
os.environ.update(env)
try:
yield
finally:
# always restore original values
os.environ.clear()
os.environ.update(original_env)
[docs]def json_serial(obj: Any) -> Any:
"""JSON serializer for objects not serializable by default json code."""
if isinstance(obj, MutableMap):
return obj.data
raise TypeError(f"Type {type(obj)} not serializable")
[docs]def load_object_from_string(
fqcn: str, try_reload: bool = False
) -> Union[type, Callable[..., Any]]:
"""Convert "." delimited strings to a python object.
Args:
fqcn: A "." delimited string representing the full path to an
object (function, class, variable) inside a module
try_reload: Try to reload the module so any global variables
set within the file during import are reloaded. This only applies
to modules that are already imported and are not builtin.
Returns:
Object being imported from the provided path.
Example::
load_object_from_string("os.path.basename")
load_object_from_string("logging.Logger")
load_object_from_string("LocalClassName")
"""
module_path = "__main__"
object_name = fqcn
if "." in object_name:
module_path, object_name = fqcn.rsplit(".", 1)
if (
try_reload
and sys.modules.get(module_path)
and module_path.split(".")[0]
not in sys.builtin_module_names # skip builtins
):
importlib.reload(sys.modules[module_path])
else:
importlib.import_module(module_path)
return getattr(sys.modules[module_path], object_name)
@overload
def merge_dicts(
dict1: Dict[Any, Any], dict2: Dict[Any, Any], deep_merge: bool = ...
) -> Dict[str, Any]:
...
@overload
def merge_dicts(
dict1: List[Any], dict2: List[Any], deep_merge: bool = ...
) -> List[Any]:
...
[docs]def merge_dicts(
dict1: Union[Dict[Any, Any], List[Any]],
dict2: Union[Dict[Any, Any], List[Any]],
deep_merge: bool = True,
) -> Union[Dict[Any, Any], List[Any]]:
"""Merge dict2 into dict1."""
if deep_merge:
if isinstance(dict1, list) and isinstance(dict2, list):
return dict1 + dict2
if not isinstance(dict1, dict) or not isinstance(dict2, dict):
return dict2
for key in dict2:
dict1[key] = (
merge_dicts(dict1[key], dict2[key], True)
if key in dict1
else dict2[key]
)
return dict1
if isinstance(dict1, dict) and isinstance(dict2, dict):
dict3 = dict1.copy()
dict3.update(dict2)
return dict3
raise ValueError(
f"values of type {type(dict1)} and {type(dict2)} must be type dict"
)
[docs]def snake_case_to_kebab_case(value: str) -> str:
"""Convert snake_case to kebab-case.
Args:
value: The string value to convert.
"""
return value.replace("_", "-")
[docs]def flatten_path_lists(
env_dict: Dict[str, Any], env_root: Optional[str] = None
) -> Dict[str, Any]:
"""Join paths in environment dict down to strings."""
for key, val in env_dict.items():
# Lists are presumed to be path components and will be turned back
# to strings
if isinstance(val, list):
env_dict[key] = (
os.path.join(env_root, os.path.join(*cast(List[str], val)))
if (env_root and not os.path.isabs(os.path.join(*cast(List[str], val))))
else os.path.join(*cast(List[str], val))
)
return env_dict
[docs]def merge_nested_environment_dicts(
env_dicts: Dict[str, Any],
env_name: Optional[str] = None,
env_root: Optional[str] = None,
) -> Dict[str, Any]:
"""Return single-level dictionary from dictionary of dictionaries."""
# If the provided dictionary is just a single "level" (no nested
# environments), it applies to all environments
if all(isinstance(val, (str, list)) for (_key, val) in env_dicts.items()):
return flatten_path_lists(env_dicts, env_root)
if env_name is None:
if env_dicts.get("*"):
return flatten_path_lists(env_dicts.get("*", {}), env_root)
return {}
if not env_dicts.get("*") and not env_dicts.get(env_name):
return {}
combined_dicts = merge_dicts(
cast(Dict[Any, Any], env_dicts.get("*", {})),
cast(Dict[Any, Any], env_dicts.get(env_name, {})),
)
return flatten_path_lists(combined_dicts, env_root)
[docs]def find_cfn_output(key: str, outputs: List[OutputTypeDef]) -> Optional[str]:
"""Return CFN output value.
Args:
key: Name of the output.
outputs: List of Stack outputs.
"""
for i in outputs:
if i.get("OutputKey") == key:
return i.get("OutputValue")
return None
[docs]def get_embedded_lib_path() -> str:
"""Return path of embedded libraries."""
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "embedded")
[docs]def get_hash_for_filename(filename: str, hashfile_path: str) -> str:
"""Return hash for filename in the hashfile."""
filehash = ""
with open(hashfile_path, "r", encoding="utf-8") as stream:
for _cnt, line in enumerate(stream):
if line.rstrip().endswith(filename):
match = re.match(r"^[A-Za-z0-9]*", line)
if match:
filehash = match.group(0)
break
if filehash:
return filehash
raise AttributeError(f"Filename {filename} not found in hash file")
[docs]@contextmanager
def ignore_exit_code_0() -> Iterator[None]:
"""Capture exit calls and ignore those with exit code 0."""
try:
yield
except SystemExit as exit_exc:
if exit_exc.code != 0:
raise
[docs]def fix_windows_command_list(commands: List[str]) -> List[str]:
"""Return command list with working Windows commands.
npm on windows is npm.cmd, which will blow up
subprocess.check_call(['npm', '...'])
Similar issues arise when calling python apps that will have a windows-only
suffix applied to them.
"""
fully_qualified_cmd_path = which(commands[0])
if fully_qualified_cmd_path:
commands[0] = os.path.basename(fully_qualified_cmd_path)
return commands
[docs]def run_commands(
commands: List[Union[str, List[str], Dict[str, Union[str, List[str]]]]],
directory: Union[Path, str],
env: Optional[Dict[str, str]] = None,
) -> None:
"""Run list of commands."""
if env is None:
env = os.environ.copy()
for step in commands:
if isinstance(step, (list, str)):
execution_dir = directory
raw_command = step
elif step.get("command"): # dictionary
execution_dir = (
os.path.join(directory, cast(str, step.get("cwd", "")))
if step.get("cwd")
else directory
)
raw_command = step["command"]
else:
raise AttributeError(f"Invalid command step: {step}")
command_list = (
raw_command.split(" ") if isinstance(raw_command, str) else raw_command
)
if platform.system().lower() == "windows":
command_list = fix_windows_command_list(command_list)
with change_dir(execution_dir):
failed_to_find_error = (
f'Attempted to run "{command_list[0]}" and failed to find it (are you sure it is '
"installed and added to your PATH?)"
)
try:
check_call(command_list, env=env)
except FileNotFoundError:
print(failed_to_find_error, file=sys.stderr) # noqa: T201
sys.exit(1)
[docs]def get_file_hash(
filename: str,
algorithm: Literal[
"blake2b",
"blake2b",
"md5",
"sha1",
"sha224",
"sha256",
"sha3_224",
"sha3_256",
"sha3_384",
"sha3_512",
"sha384",
"sha512",
"shake_128",
"shake_256",
],
) -> str:
"""Return cryptographic hash of file.
.. deprecated:: 2.4.0
Use :class:`runway.utils.FileHash` instead.
"""
LOGGER.warning(
"%s.get_file_hash is deprecated and will be removed in the next major release",
__name__,
)
file_hash = getattr(hashlib, algorithm)()
with open(filename, "rb") as stream:
while True:
data = stream.read(65536) # 64kb chunks
if not data:
break
file_hash.update(data)
return file_hash.hexdigest()
[docs]def md5sum(filename: str) -> str:
"""Return MD5 hash of file.
.. deprecated:: 2.4.0
Use :class:`runway.utils.FileHash` instead.
"""
LOGGER.warning(
"%s.md5sum is deprecated and will be removed in the next major release",
__name__,
)
return get_file_hash(filename, "md5")
[docs]def sha256sum(filename: str) -> str:
"""Return SHA256 hash of file.
.. deprecated:: 2.4.0
Use :class:`runway.utils.FileHash` instead.
"""
LOGGER.warning(
"%s.sha256sum is deprecated and will be removed in the next major release",
__name__,
)
sha256 = hashlib.sha256()
with open(filename, "rb", buffering=0) as stream:
mem_view = memoryview(bytearray(128 * 1024))
for i in iter(lambda: stream.readinto(mem_view), 0):
sha256.update(mem_view[:i])
return sha256.hexdigest()
[docs]@contextmanager
def use_embedded_pkgs(embedded_lib_path: Optional[str] = None) -> Iterator[None]:
"""Temporarily prepend embedded packages to sys.path."""
if embedded_lib_path is None:
embedded_lib_path = get_embedded_lib_path()
old_sys_path = list(sys.path)
sys.path.insert(1, embedded_lib_path) # https://stackoverflow.com/a/10097543
try:
yield
finally:
sys.path = old_sys_path
[docs]def which(program: str) -> Optional[str]:
"""Mimic 'which' command behavior."""
def is_exe(fpath: str) -> bool:
"""Determine if program exists and is executable."""
return os.path.isfile(fpath) and os.access(fpath, os.X_OK)
def get_extensions() -> List[str]:
"""Get PATHEXT if the exist, otherwise use default."""
exts = ".COM;.EXE;.BAT;.CMD;.VBS;.VBE;.JS;.JSE;.WSF;.WSH;.MSC"
if os.getenv("PATHEXT"):
exts = os.environ["PATHEXT"]
return exts.split(";")
fname, file_ext = os.path.splitext(program)
fpath, fname = os.path.split(program)
if not file_ext and platform.system().lower() == "windows":
fnames = [fname + ext for ext in get_extensions()]
else:
fnames = [fname]
for i in fnames:
if fpath:
exe_file = os.path.join(fpath, i)
if is_exe(exe_file):
return exe_file
else:
for path in (
os.environ.get("PATH", "").split(os.pathsep)
if "PATH" in os.environ
else [os.getcwd()]
):
exe_file = os.path.join(path, i)
if is_exe(exe_file):
return exe_file
return None