# Source code for runway.cfngin.dag

```
"""CFNgin directed acyclic graph (DAG) implementation."""
import collections
import logging
from collections import OrderedDict, deque
from copy import copy, deepcopy
from threading import Thread
# Logger for this module, named after the module's import path.
LOGGER = logging.getLogger(name=__name__)
class DAG(object):
    """Directed acyclic graph implementation.

    The graph is stored in ``self.graph`` as an ``OrderedDict`` mapping each
    node name to the ``set`` of node names it has edges towards.

    """

    def __init__(self):
        """Instantiate a new DAG with no nodes or edges."""
        self.graph = OrderedDict()

    def reset_graph(self):
        """Restore the graph to an empty state (no nodes and no edges)."""
        self.graph = OrderedDict()

    def add_node(self, node_name):
        """Add a node if it does not exist yet, or error out.

        Args:
            node_name (str): The unique name of the node to add.

        Raises:
            KeyError: Raised if a node with the same name already exists in
                the graph.

        """
        graph = self.graph
        if node_name in graph:
            raise KeyError("node %s already exists" % node_name)
        graph[node_name] = set()

    def add_node_if_not_exists(self, node_name):
        """Add a node if it does not exist yet, ignoring duplicates.

        Args:
            node_name (str): The name of the node to add.

        """
        try:
            self.add_node(node_name)
        except KeyError:
            pass  # an existing node is deliberately left untouched

    def delete_node(self, node_name):
        """Delete this node and all edges referencing it.

        Args:
            node_name (str): The name of the node to delete.

        Raises:
            KeyError: Raised if the node does not exist in the graph.

        """
        graph = self.graph
        if node_name not in graph:
            raise KeyError("node %s does not exist" % node_name)
        graph.pop(node_name)
        # Remove every edge that pointed at the deleted node.
        for edges in graph.values():
            edges.discard(node_name)

    def delete_node_if_exists(self, node_name):
        """Delete this node and all edges referencing it.

        Ignores any node that is not in the graph, rather than throwing an
        exception.

        Args:
            node_name (str): The name of the node to delete.

        """
        try:
            self.delete_node(node_name)
        except KeyError:
            pass  # a missing node is deliberately ignored

    def add_edge(self, ind_node, dep_node):
        """Add an edge (dependency) between the specified nodes.

        The edge is stored as ``graph[ind_node]`` containing ``dep_node``.

        Args:
            ind_node (str): The independent node to add an edge to.
            dep_node (str): The dependent node that has a dependency on the
                ind_node.

        Raises:
            KeyError: Either the ind_node, or dep_node do not exist.
            DAGValidationError: Raised if the resulting graph is invalid.

        """
        graph = self.graph
        if ind_node not in graph:
            raise KeyError("independent node %s does not exist" % ind_node)
        if dep_node not in graph:
            raise KeyError("dependent node %s does not exist" % dep_node)
        # Try the edge on a copy first so that an edge which would make the
        # graph invalid (e.g. introduce a cycle) never touches self.graph.
        test_graph = deepcopy(graph)
        test_graph[ind_node].add(dep_node)
        test_dag = DAG()
        test_dag.graph = test_graph
        is_valid, message = test_dag.validate()
        if is_valid:
            graph[ind_node].add(dep_node)
        else:
            raise DAGValidationError(message)

    def delete_edge(self, ind_node, dep_node):
        """Delete an edge from the graph.

        Args:
            ind_node (str): The independent node to delete an edge from.
            dep_node (str): The dependent node that has a dependency on the
                ind_node.

        Raises:
            KeyError: Raised when the edge doesn't already exist.

        """
        graph = self.graph
        if dep_node not in graph.get(ind_node, []):
            raise KeyError("No edge exists between %s and %s." % (ind_node, dep_node))
        graph[ind_node].remove(dep_node)

    def transpose(self):
        """Build a new graph with the edges reversed.

        Returns:
            :class:`runway.cfngin.dag.DAG`: The transposed graph.

        """
        graph = self.graph
        transposed = DAG()
        for node in graph:
            transposed.add_node(node)
        for node, edges in graph.items():
            # for each edge A -> B, transpose it so that B -> A
            for edge in edges:
                transposed.add_edge(edge, node)
        return transposed

    def walk(self, walk_func):
        """Walk each node of the graph in reverse topological order.

        This can be used to perform a set of operations, where the next
        operation depends on the previous operation. It's important to note
        that walking happens serially, and is not parallelized.

        Args:
            walk_func (:class:`types.FunctionType`): The function to be called
                on each node of the graph.

        """
        nodes = self.topological_sort()
        # Reverse so we start with nodes that have no dependencies.
        nodes.reverse()
        for node in nodes:
            walk_func(node)

    def transitive_reduction(self):
        """Perform a transitive reduction on the DAG, in place.

        The transitive reduction of a graph is a graph with as few edges as
        possible with the same reachability as the original graph. An edge
        ``node -> child`` is removed when ``child`` is also reachable from
        ``node`` through one of its other children.

        See https://en.wikipedia.org/wiki/Transitive_reduction

        """
        graph = self.graph
        reachable_cache = {}

        def _reachable(start):
            """Return the set of nodes reachable from ``start`` via >= 1 edge."""
            if start in reachable_cache:
                return reachable_cache[start]
            seen = set()
            stack = list(graph.get(start, ()))
            while stack:
                current = stack.pop()
                if current in seen:
                    continue
                seen.add(current)
                stack.extend(graph.get(current, ()))
            reachable_cache[start] = seen
            return seen

        # Compute every reduced edge set before assigning any, so reachability
        # is always evaluated against the unmodified graph.
        reduced = []
        for node, edges in graph.items():
            indirect = set()
            for child in edges:
                indirect |= _reachable(child)
            reduced.append((node, edges - indirect))
        for node, edges in reduced:
            graph[node] = edges

    def rename_edges(self, old_node_name, new_node_name):
        """Change references to a node in existing edges.

        The node itself is also renamed; the renamed node is moved to the end
        of the graph's insertion order.

        Args:
            old_node_name (str): The old name for the node.
            new_node_name (str): The new name for the node.

        """
        if old_node_name == new_node_name:
            # Renaming to the same name is a no-op; without this guard the
            # node would be re-added and then deleted.
            return
        graph = self.graph
        # Iterate over a snapshot: the OrderedDict is modified inside the loop.
        for node, edges in list(graph.items()):
            if node == old_node_name:
                graph[new_node_name] = copy(edges)
                del graph[old_node_name]
            elif old_node_name in edges:
                edges.remove(old_node_name)
                edges.add(new_node_name)

    def predecessors(self, node):
        """Return a list of all immediate predecessors of the given node.

        Args:
            node (str): The node whose predecessors you want to find.

        Returns:
            List[str]: A list of nodes that are immediate predecessors to node.

        """
        graph = self.graph
        return [key for key in graph if node in graph[key]]

    def downstream(self, node):
        """Return a list of all nodes this node has edges towards.

        Args:
            node (str): The node whose downstream nodes you want to find.

        Returns:
            List[str]: A list of nodes that are immediately downstream from the
            node.

        Raises:
            KeyError: Raised if the node is not in the graph.

        """
        graph = self.graph
        if node not in graph:
            raise KeyError("node %s is not in graph" % node)
        return list(graph[node])

    def all_downstreams(self, node):
        """Return a list of all nodes downstream in topological order.

        Args:
            node (str): The node whose downstream nodes you want to find.

        Returns:
            List[str]: A list of nodes that are downstream from the node.

        """
        # Breadth-first expansion; ``nodes`` doubles as the work queue.
        nodes = [node]
        nodes_seen = set()
        i = 0
        while i < len(nodes):
            downstreams = self.downstream(nodes[i])
            for downstream_node in downstreams:
                if downstream_node not in nodes_seen:
                    nodes_seen.add(downstream_node)
                    nodes.append(downstream_node)
            i += 1
        return [node_ for node_ in self.topological_sort() if node_ in nodes_seen]

    def filter(self, nodes):
        """Return a new DAG with only the given nodes and their dependencies.

        Args:
            nodes (list): The nodes you are interested in.

        Returns:
            :class:`DAG`: The filtered graph. Its edge sets are copies, so
            mutating it does not affect this graph.

        """
        filtered_dag = DAG()
        # Add only the nodes we need.
        for node in nodes:
            filtered_dag.add_node_if_not_exists(node)
            for edge in self.all_downstreams(node):
                filtered_dag.add_node_if_not_exists(edge)
        # Now, rebuild the graph for each node that's present. Copy the edge
        # sets so the two graphs never share mutable state.
        for node, edges in self.graph.items():
            if node in filtered_dag.graph:
                filtered_dag.graph[node] = copy(edges)
        return filtered_dag

    def all_leaves(self):
        """Return a list of all leaves (nodes with no downstreams).

        Returns:
            List[str]: A list of all the nodes with no downstreams.

        """
        graph = self.graph
        return [key for key in graph if not graph[key]]

    def from_dict(self, graph_dict):
        """Reset the graph and build it from the passed dictionary.

        The dictionary takes the form of {node_name: [directed edges]}

        Args:
            graph_dict (Dict[str, Any]): The dictionary used to create the
                graph.

        Raises:
            TypeError: Raised if the value of items in the dict are not lists.

        """
        # collections.Iterable was removed in Python 3.10; use the abc module.
        from collections.abc import Iterable

        self.reset_graph()
        for new_node in graph_dict:
            self.add_node(new_node)
        for ind_node, dep_nodes in graph_dict.items():
            if not isinstance(dep_nodes, Iterable):
                raise TypeError("%s: dict values must be lists" % ind_node)
            for dep_node in dep_nodes:
                self.add_edge(ind_node, dep_node)

    def ind_nodes(self):
        """Return a list of all nodes in the graph with no dependencies.

        Returns:
            List[str]: A list of all independent nodes (no incoming edges).

        """
        graph = self.graph
        dependent_nodes = {node for dependents in graph.values() for node in dependents}
        return [node_ for node_ in graph if node_ not in dependent_nodes]

    def validate(self):
        """Return (Boolean, message) of whether DAG is valid.

        Returns:
            Tuple[bool, str]

        """
        if not self.ind_nodes():
            return (False, "no independent nodes detected")
        try:
            self.topological_sort()
        except ValueError as err:
            return False, str(err)
        return True, "valid"

    def topological_sort(self):
        """Return a topological ordering of the DAG.

        Returns:
            list: A list of topologically sorted nodes in the graph.

        Raises:
            ValueError: Raised if the graph is not acyclic.

        """
        # Kahn's algorithm: repeatedly emit nodes whose in-degree is zero.
        graph = self.graph
        in_degree = {}
        for node in graph:
            in_degree[node] = 0
        for node in graph:
            for val in graph[node]:
                in_degree[val] += 1
        queue = deque()
        for node in in_degree:
            if in_degree[node] == 0:
                queue.appendleft(node)
        sorted_graph = []
        while queue:
            node = queue.pop()
            sorted_graph.append(node)
            # Sorting the children makes the ordering deterministic.
            for val in sorted(graph[node]):
                in_degree[val] -= 1
                if in_degree[val] == 0:
                    queue.appendleft(val)
        # Nodes left unvisited are part of a cycle.
        if len(sorted_graph) == len(graph):
            return sorted_graph
        raise ValueError("graph is not acyclic")

    def __len__(self):
        """How the length of a DAG is calculated."""
        return len(self.graph)
class ThreadedWalker(object):  # pylint: disable=too-few-public-methods
    """Walk a DAG as quickly as the graph topology allows, using threads."""

    def __init__(self, semaphore):
        """Instantiate class.

        Args:
            semaphore (threading.Semaphore): a semaphore object which
                can be used to control how many steps are executed in parallel.
                Only its ``acquire()`` and ``release()`` methods are used.

        """
        self.semaphore = semaphore

    def walk(self, dag, walk_func):
        """Walk each node of the graph, in parallel if it can.

        One thread is created per node. The walk_func is only called for a
        node once the threads of all of its downstream nodes (as returned by
        ``dag.all_downstreams``) have finished.

        Args:
            dag: The graph to walk; ``topological_sort()`` and
                ``all_downstreams()`` are called on it.
            walk_func: Called with each node name. Its return value is
                discarded.

        """
        # First, we'll topologically sort all of the nodes, with nodes that
        # have no dependencies first. We do this to ensure that we don't call
        # .join on a thread that hasn't yet been started.
        #
        # TODO(ejholmes): An alternative would be to ensure that Thread.join
        # blocks if the thread has not yet been started.
        nodes = dag.topological_sort()
        nodes.reverse()

        # This maps a node name to a thread of execution.
        threads = {}

        def wait_for(nodes_to_wait_on):
            """Block until the threads for the given nodes have all finished.

            Completion is waited on whether walk_func succeeded or raised;
            nothing is returned and walk_func results are not inspected.

            """
            for node_name in nodes_to_wait_on:
                thread = threads[node_name]
                # Join with a short timeout in a loop instead of one blocking
                # join; presumably to keep the waiting thread responsive
                # (e.g. to KeyboardInterrupt) - confirm.
                while thread.is_alive():
                    thread.join(0.5)

        def _fn(node_, deps):
            """Thread target: wait for ``deps``, then run walk_func(node_)."""
            if deps:
                LOGGER.debug("%s waiting for %s to complete", node_, ", ".join(deps))
                # Wait for all dependencies to complete.
                wait_for(deps)
            LOGGER.debug("%s starting", node_)
            # Explicit acquire/release (rather than ``with``) so objects that
            # only provide acquire()/release() still work.
            self.semaphore.acquire()
            try:
                return walk_func(node_)
            finally:
                self.semaphore.release()

        # For each node in the graph, we're going to allocate a thread to
        # execute. The thread will block executing walk_func, until all of the
        # nodes dependencies have executed.
        for node in nodes:
            deps = dag.all_downstreams(node)
            threads[node] = Thread(target=_fn, args=(node, deps), name=node)

        # Start up all of the threads.
        for node in nodes:
            threads[node].start()

        # Wait for all threads to complete executing.
        wait_for(nodes)
```