Source code for runway.cfngin.dag

"""CFNgin directed acyclic graph (DAG) implementation."""
from __future__ import annotations

import collections
import collections.abc
import logging
from copy import copy, deepcopy
from threading import Thread
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    OrderedDict,
    Set,
    Tuple,
    Union,
    cast,
)

if TYPE_CHECKING:
    import threading

LOGGER = logging.getLogger(__name__)


[docs]class DAGValidationError(Exception): """Raised when DAG validation fails."""
[docs]class DAG: """Directed acyclic graph implementation.""" graph: OrderedDict[str, Set[str]]
[docs] def __init__(self) -> None: """Instantiate a new DAG with no nodes or edges.""" self.graph = collections.OrderedDict()
[docs] def add_node(self, node_name: str) -> None: """Add a node if it does not exist yet, or error out. Args: node_name: The unique name of the node to add. Raises: KeyError: Raised if a node with the same name already exist in the graph """ graph = self.graph if node_name in graph: raise KeyError(f"node {node_name} already exists") graph[node_name] = cast(Set[str], set())
[docs] def add_node_if_not_exists(self, node_name: str) -> None: """Add a node if it does not exist yet, ignoring duplicates. Args: node_name: The name of the node to add. """ try: self.add_node(node_name) except KeyError: pass
[docs] def delete_node(self, node_name: str) -> None: """Delete this node and all edges referencing it. Args: node_name: The name of the node to delete. Raises: KeyError: Raised if the node does not exist in the graph. """ graph = self.graph if node_name not in graph: raise KeyError(f"node {node_name} does not exist") graph.pop(node_name) for _node, edges in graph.items(): if node_name in edges: edges.remove(node_name)
[docs] def delete_node_if_exists(self, node_name: str) -> None: """Delete this node and all edges referencing it. Ignores any node that is not in the graph, rather than throwing an exception. Args: node_name: The name of the node to delete. """ try: self.delete_node(node_name) except KeyError: pass
[docs] def add_edge(self, ind_node: str, dep_node: str) -> None: """Add an edge (dependency) between the specified nodes. Args: ind_node: The independent node to add an edge to. dep_node: The dependent node that has a dependency on the ind_node. Raises: KeyError: Either the ind_node, or dep_node do not exist. DAGValidationError: Raised if the resulting graph is invalid. """ graph = self.graph if ind_node not in graph: raise KeyError(f"independent node {ind_node} does not exist") if dep_node not in graph: raise KeyError(f"dependent node {dep_node} does not exist") test_graph = deepcopy(graph) test_graph[ind_node].add(dep_node) test_dag = DAG() test_dag.graph = test_graph is_valid, message = test_dag.validate() if is_valid: graph[ind_node].add(dep_node) else: raise DAGValidationError(message)
[docs] def delete_edge(self, ind_node: str, dep_node: str) -> None: """Delete an edge from the graph. Args: ind_node: The independent node to delete an edge from. dep_node: The dependent node that has a dependency on the ind_node. Raises: KeyError: Raised when the edge doesn't already exist. """ graph = self.graph if dep_node not in graph.get(ind_node, []): raise KeyError(f"No edge exists between {ind_node} and {dep_node}.") graph[ind_node].remove(dep_node)
[docs] def transpose(self) -> DAG: """Build a new graph with the edges reversed.""" graph = self.graph transposed = DAG() for node, _edges in graph.items(): transposed.add_node(node) for node, edges in graph.items(): # for each edge A -> B, transpose it so that B -> A for edge in edges: transposed.add_edge(edge, node) return transposed
[docs] def walk(self, walk_func: Callable[[str], Any]) -> None: """Walk each node of the graph in reverse topological order. This can be used to perform a set of operations, where the next operation depends on the previous operation. It's important to note that walking happens serially, and is not parallelized. Args: walk_func: The function to be called on each node of the graph. """ nodes = self.topological_sort() # Reverse so we start with nodes that have no dependencies. nodes.reverse() for node in nodes: walk_func(node)
[docs] def transitive_reduction(self) -> None: """Perform a transitive reduction on the DAG. The transitive reduction of a graph is a graph with as few edges as possible with the same reachability as the original graph. See https://en.wikipedia.org/wiki/Transitive_reduction """ combinations: List[List[str]] = [] for node, edges in self.graph.items(): combinations += [[node, edge] for edge in edges] while True: new_combinations: List[List[str]] = [] for comb1 in combinations: for comb2 in combinations: if comb1[-1] != comb2[0]: continue new_entry = comb1 + comb2[1:] if new_entry not in combinations: new_combinations.append(new_entry) if not new_combinations: break combinations += new_combinations constructed = {(c[0], c[-1]) for c in combinations if len(c) != 2} for node, edges in self.graph.items(): bad_nodes = {e for n, e in constructed if node == n} self.graph[node] = edges - bad_nodes
[docs] def rename_edges(self, old_node_name: str, new_node_name: str) -> None: """Change references to a node in existing edges. Args: old_node_name: The old name for the node. new_node_name: The new name for the node. """ graph = self.graph for node, edges in graph.items(): if node == old_node_name: graph[new_node_name] = copy(edges) del graph[old_node_name] else: if old_node_name in edges: edges.remove(old_node_name) edges.add(new_node_name)
[docs] def predecessors(self, node: str) -> List[str]: """Return a list of all immediate predecessors of the given node. Args: node (str): The node whose predecessors you want to find. Returns: List[str]: A list of nodes that are immediate predecessors to node. """ graph = self.graph return [key for key in graph if node in graph[key]]
[docs] def downstream(self, node: str) -> List[str]: """Return a list of all nodes this node has edges towards. Args: node: The node whose downstream nodes you want to find. Returns: A list of nodes that are immediately downstream from the node. """ graph = self.graph if node not in graph: raise KeyError(f"node {node} is not in graph") return list(graph[node])
[docs] def all_downstreams(self, node: str) -> List[str]: """Return a list of all nodes downstream in topological order. Args: node: The node whose downstream nodes you want to find. Returns: A list of nodes that are downstream from the node. """ nodes = [node] nodes_seen: Set[str] = set() nodes_iter = nodes for node__ in nodes_iter: downstreams = self.downstream(node__) for downstream_node in downstreams: if downstream_node not in nodes_seen: nodes_seen.add(downstream_node) nodes.append(downstream_node) return [node_ for node_ in self.topological_sort() if node_ in nodes_seen]
[docs] def filter(self, nodes: List[str]) -> DAG: """Return a new DAG with only the given nodes and their dependencies. Args: nodes: The nodes you are interested in. """ filtered_dag = DAG() # Add only the nodes we need. for node in nodes: filtered_dag.add_node_if_not_exists(node) for edge in self.all_downstreams(node): filtered_dag.add_node_if_not_exists(edge) # Now, rebuild the graph for each node that's present. for node, edges in self.graph.items(): if node in filtered_dag.graph: filtered_dag.graph[node] = edges return filtered_dag
[docs] def all_leaves(self) -> List[str]: """Return a list of all leaves (nodes with no downstreams).""" graph = self.graph return [key for key in graph if not graph[key]]
[docs] def from_dict(self, graph_dict: Dict[str, Union[Iterable[str], Any]]) -> None: """Reset the graph and build it from the passed dictionary. The dictionary takes the form of {node_name: [directed edges]} Args: graph_dict: The dictionary used to create the graph. Raises: TypeError: Raised if the value of items in the dict are not lists. """ self.reset_graph() for new_node in graph_dict: self.add_node(new_node) for ind_node, dep_nodes in graph_dict.items(): if not isinstance(dep_nodes, collections.abc.Iterable): raise TypeError(f"{ind_node}: dict values must be lists") for dep_node in dep_nodes: self.add_edge(ind_node, dep_node)
[docs] def reset_graph(self) -> None: """Restore the graph to an empty state.""" self.graph = collections.OrderedDict()
[docs] def ind_nodes(self) -> List[str]: """Return a list of all nodes in the graph with no dependencies.""" graph = self.graph dependent_nodes = {node for dependents in graph.values() for node in dependents} return [node_ for node_ in graph if node_ not in dependent_nodes]
[docs] def validate(self) -> Tuple[bool, str]: """Return (Boolean, message) of whether DAG is valid.""" if not self.ind_nodes(): return (False, "no independent nodes detected") try: self.topological_sort() except ValueError as err: return False, str(err) return True, "valid"
[docs] def topological_sort(self) -> List[str]: """Return a topological ordering of the DAG. Raises: ValueError: Raised if the graph is not acyclic. """ graph = self.graph in_degree = {node: 0 for node in graph} for node in graph: for val in graph[node]: in_degree[val] += 1 queue: "collections.deque[str]" = collections.deque() for node, value in in_degree.items(): if value == 0: queue.appendleft(node) sorted_graph: List[str] = [] while queue: node = queue.pop() sorted_graph.append(node) for val in sorted(graph[node]): in_degree[val] -= 1 if in_degree[val] == 0: queue.appendleft(val) if len(sorted_graph) == len(graph): return sorted_graph raise ValueError("graph is not acyclic")
[docs] def size(self) -> int: """Count of nodes in the graph.""" return len(self)
[docs] def __len__(self) -> int: """How the length of a DAG is calculated.""" return len(self.graph)
[docs]def walk(dag: DAG, walk_func: Callable[[str], Any]) -> None: """Walk a DAG.""" return dag.walk(walk_func)
[docs]class UnlimitedSemaphore: """threading.Semaphore, but acquire always succeeds."""
[docs] def acquire(self, *args: Any) -> Any: """Do nothing."""
[docs] def release(self) -> Any: """Do nothing."""
[docs]class ThreadedWalker: """Walk a DAG as quickly as the graph topology allows, using threads."""
[docs] def __init__(self, semaphore: Union[threading.Semaphore, UnlimitedSemaphore]): """Instantiate class. Args: semaphore: A semaphore object which can be used to control how many steps are executed in parallel. """ self.semaphore = semaphore
[docs] def walk(self, dag: DAG, walk_func: Callable[[str], Any]) -> None: """Walk each node of the graph, in parallel if it can. The walk_func is only called when the nodes dependencies have been satisfied. """ # First, we'll topologically sort all of the nodes, with nodes that # have no dependencies first. We do this to ensure that we don't call # .join on a thread that hasn't yet been started. # # TODO(ejholmes): An alternative would be to ensure that Thread.join # blocks if the thread has not yet been started. nodes = dag.topological_sort() nodes.reverse() # This maps a node name to a thread of execution. threads: Dict[str, Any] = {} # Blocks until all of the given nodes have completed execution (whether # successfully, or errored). Returns True if all nodes returned True. def wait_for(nodes: List[str]): """Wait for nodes.""" for node in nodes: thread = threads[node] while thread.is_alive(): threads[node].join(0.5) # For each node in the graph, we're going to allocate a thread to # execute. The thread will block executing walk_func, until all of the # nodes dependencies have executed. for node in nodes: def _fn(node_: str, deps: List[str]) -> Any: if deps: LOGGER.debug( "%s waiting for %s to complete", node_, ", ".join(deps) ) # Wait for all dependencies to complete. wait_for(deps) LOGGER.debug("%s starting", node_) self.semaphore.acquire() try: return walk_func(node_) finally: self.semaphore.release() deps = dag.all_downstreams(node) threads[node] = Thread(target=_fn, args=(node, deps), name=node) # Start up all of the threads. for node in nodes: threads[node].start() # Wait for all threads to complete executing. wait_for(nodes)