#17: remove Node._parents

This commit is contained in:
Karl Kroening 2017-07-05 04:07:30 -06:00
parent 7236984626
commit fc07f6c4fa
3 changed files with 41 additions and 30 deletions

View File

@ -55,7 +55,7 @@ def _get_input_args(input_node):
def _get_filter_spec(i, node, stream_name_map):
stream_name = _get_stream_name('v{}'.format(i))
stream_name_map[node] = stream_name
inputs = [stream_name_map[parent] for parent in node._parents]
inputs = [stream_name_map[edge.upstream_node] for edge in node.incoming_edges]
filter_spec = '{}{}{}'.format(''.join(inputs), node._get_filter(), stream_name)
return filter_spec
@ -75,7 +75,8 @@ def _get_global_args(node):
def _get_output_args(node, stream_name_map):
args = []
if node.name != merge_outputs.__name__:
stream_name = stream_name_map[node._parents[0]]
assert len(node.incoming_edges) == 1
stream_name = stream_name_map[node.incoming_edges[0].upstream_node]
if stream_name != '[0]':
args += ['-map', stream_name]
if node.name == output.__name__:
@ -96,7 +97,7 @@ def get_args(node):
"""Get command-line arguments for ffmpeg."""
args = []
# TODO: group nodes together, e.g. `-i somefile -r somerate`.
sorted_nodes, child_map = topo_sort([node])
sorted_nodes, outgoing_edge_maps = topo_sort([node])
del(node)
input_nodes = [node for node in sorted_nodes if isinstance(node, InputNode)]
output_nodes = [node for node in sorted_nodes if isinstance(node, OutputNode) and not

View File

@ -96,6 +96,20 @@ class DagNode(object):
DagEdge = namedtuple('DagEdge', ['downstream_node', 'downstream_label', 'upstream_node', 'upstream_label'])
def get_incoming_edges(downstream_node, incoming_edge_map):
edges = []
for downstream_label, (upstream_node, upstream_label) in incoming_edge_map.items():
edges += [DagEdge(downstream_node, downstream_label, upstream_node, upstream_label)]
return edges
def get_outgoing_edges(upstream_node, outgoing_edge_map):
edges = []
for upstream_label, (downstream_node, downstream_label) in outgoing_edge_map:
edges += [DagEdge(downstream_node, downstream_label, upstream_node, upstream_label)]
return edges
class KwargReprNode(DagNode):
"""A DagNode that can be represented as a set of args+kwargs.
"""
@ -142,11 +156,7 @@ class KwargReprNode(DagNode):
@property
def incoming_edges(self):
edges = []
for downstream_label, (upstream_node, upstream_label) in self.incoming_edge_map.items():
downstream_node = self
edges += [DagEdge(downstream_node, downstream_label, upstream_node, upstream_label)]
return edges
return get_incoming_edges(self, self.incoming_edge_map)
@property
def incoming_edge_map(self):
@ -157,24 +167,29 @@ class KwargReprNode(DagNode):
return self.name
def topo_sort(start_nodes):
def topo_sort(downstream_nodes):
marked_nodes = []
sorted_nodes = []
child_map = {}
def visit(node, child):
if node in marked_nodes:
outgoing_edge_maps = {}
def visit(upstream_node, upstream_label, downstream_node, downstream_label):
if upstream_node in marked_nodes:
raise RuntimeError('Graph is not a DAG')
if child is not None:
if node not in child_map:
child_map[node] = []
child_map[node].append(child)
if node not in sorted_nodes:
marked_nodes.append(node)
parents = [edge.upstream_node for edge in node.incoming_edges]
[visit(parent, node) for parent in parents]
marked_nodes.remove(node)
sorted_nodes.append(node)
unmarked_nodes = list(copy.copy(start_nodes))
if downstream_node is not None:
if upstream_node not in outgoing_edge_maps:
outgoing_edge_maps[upstream_node] = {}
outgoing_edge_maps[upstream_node][upstream_label] = (downstream_node, downstream_label)
if upstream_node not in sorted_nodes:
marked_nodes.append(upstream_node)
for edge in upstream_node.incoming_edges:
visit(edge.upstream_node, edge.upstream_label, edge.downstream_node, edge.downstream_label)
marked_nodes.remove(upstream_node)
sorted_nodes.append(upstream_node)
unmarked_nodes = [(node, 0) for node in downstream_nodes]
while unmarked_nodes:
visit(unmarked_nodes.pop(), None)
return sorted_nodes, child_map
upstream_node, upstream_label = unmarked_nodes.pop()
visit(upstream_node, upstream_label, None, None)
return sorted_nodes, outgoing_edge_maps

View File

@ -13,11 +13,6 @@ class Node(KwargReprNode):
incoming_edge_map[downstream_label] = (upstream_node, upstream_label)
super(Node, self).__init__(incoming_edge_map, name, args, kwargs)
@property
def _parents(self):
# TODO: change graph compilation to use `self.incoming_edges` instead.
return [edge.upstream_node for edge in self.incoming_edges]
class InputNode(Node):
"""InputNode type"""