Take into account upstream selectors in topological sort, get_args() and view()

This commit is contained in:
Davide Depau 2017-12-21 15:35:39 +01:00
parent f6d014540a
commit 3f671218a6
No known key found for this signature in database
GPG Key ID: C7D999B6A55EFE86
3 changed files with 22 additions and 13 deletions

View File

@ -23,7 +23,7 @@ from .nodes import (
def _get_stream_name(name):
return '[{}]'.format(name)
return '{}'.format(name)
def _convert_kwargs_to_cmd_line_args(kwargs):
@ -57,8 +57,8 @@ def _get_input_args(input_node):
def _get_filter_spec(node, outgoing_edge_map, stream_name_map):
incoming_edges = node.incoming_edges
outgoing_edges = get_outgoing_edges(node, outgoing_edge_map)
inputs = [stream_name_map[edge.upstream_node, edge.upstream_label] for edge in incoming_edges]
outputs = [stream_name_map[edge.upstream_node, edge.upstream_label] for edge in outgoing_edges]
inputs = ["[{}{}]".format(stream_name_map[edge.upstream_node, edge.upstream_label], "" if not edge.upstream_selector else ":{}".format(edge.upstream_selector)) for edge in incoming_edges]
outputs = ["[{}]".format(stream_name_map[edge.upstream_node, edge.upstream_label]) for edge in outgoing_edges]
filter_spec = '{}{}{}'.format(''.join(inputs), node._get_filter(outgoing_edges), ''.join(outputs))
return filter_spec
@ -95,7 +95,7 @@ def _get_output_args(node, stream_name_map):
args = []
assert len(node.incoming_edges) == 1
edge = node.incoming_edges[0]
stream_name = stream_name_map[edge.upstream_node, edge.upstream_label]
stream_name = "[{}{}]".format(stream_name_map[edge.upstream_node, edge.upstream_label], "" if not edge.upstream_selector else ":{}".format(edge.upstream_selector))
if stream_name != '[0]':
args += ['-map', stream_name]
kwargs = copy.copy(node.kwargs)

View File

@ -62,9 +62,13 @@ def view(stream_spec, **kwargs):
kwargs = {}
up_label = edge.upstream_label
down_label = edge.downstream_label
if show_labels and (up_label is not None or down_label is not None):
up_selector = edge.upstream_selector
if show_labels and (up_label is not None or down_label is not None or up_selector is not None):
if up_label is None:
up_label = ''
if up_selector is not None:
up_label += ":" + up_selector
if down_label is None:
down_label = ''
if up_label != '' and down_label != '':

View File

@ -70,21 +70,26 @@ class DagNode(object):
raise NotImplementedError()
DagEdge = namedtuple('DagEdge', ['downstream_node', 'downstream_label', 'upstream_node', 'upstream_label'])
DagEdge = namedtuple('DagEdge', ['downstream_node', 'downstream_label', 'upstream_node', 'upstream_label', 'upstream_selector'])
def get_incoming_edges(downstream_node, incoming_edge_map):
edges = []
for downstream_label, (upstream_node, upstream_label) in list(incoming_edge_map.items()):
edges += [DagEdge(downstream_node, downstream_label, upstream_node, upstream_label)]
# downstream_label, (upstream_node, upstream_label) in [(i[0], i[1][:2]) for i in self.incoming_edge_map.items()]
for downstream_label, upstream_info in [(i[0], i[1]) for i in incoming_edge_map.items()]:
upstream_node, upstream_label = upstream_info[:2]
upstream_selector = None if len(upstream_info) < 3 else upstream_info[2]
edges += [DagEdge(downstream_node, downstream_label, upstream_node, upstream_label, upstream_selector)]
return edges
def get_outgoing_edges(upstream_node, outgoing_edge_map):
edges = []
for upstream_label, downstream_infos in list(outgoing_edge_map.items()):
for (downstream_node, downstream_label) in downstream_infos:
edges += [DagEdge(downstream_node, downstream_label, upstream_node, upstream_label)]
for downstream_info in downstream_infos:
downstream_node, downstream_label = downstream_info[:2]
downstream_selector = None if len(downstream_info) < 3 else downstream_info[2]
edges += [DagEdge(downstream_node, downstream_label, upstream_node, upstream_label, downstream_selector)]
return edges
@ -155,21 +160,21 @@ def topo_sort(downstream_nodes):
sorted_nodes = []
outgoing_edge_maps = {}
def visit(upstream_node, upstream_label, downstream_node, downstream_label):
def visit(upstream_node, upstream_label, downstream_node, downstream_label, downstream_selector=None):
if upstream_node in marked_nodes:
raise RuntimeError('Graph is not a DAG')
if downstream_node is not None:
outgoing_edge_map = outgoing_edge_maps.get(upstream_node, {})
outgoing_edge_infos = outgoing_edge_map.get(upstream_label, [])
outgoing_edge_infos += [(downstream_node, downstream_label)]
outgoing_edge_infos += [(downstream_node, downstream_label, downstream_selector)]
outgoing_edge_map[upstream_label] = outgoing_edge_infos
outgoing_edge_maps[upstream_node] = outgoing_edge_map
if upstream_node not in sorted_nodes:
marked_nodes.append(upstream_node)
for edge in upstream_node.incoming_edges:
visit(edge.upstream_node, edge.upstream_label, edge.downstream_node, edge.downstream_label)
visit(edge.upstream_node, edge.upstream_label, edge.downstream_node, edge.downstream_label, edge.upstream_selector)
marked_nodes.remove(upstream_node)
sorted_nodes.append(upstream_node)