#17: move graph stuff to dag.py; add edge labelling

This commit is contained in:
Karl Kroening 2017-07-05 03:13:30 -06:00
parent fc946be164
commit 6a9a12e718
4 changed files with 206 additions and 131 deletions

View File

@ -1,5 +1,6 @@
from __future__ import unicode_literals
from .dag import topo_sort
from functools import reduce
from past.builtins import basestring
import copy
@ -34,8 +35,8 @@ def _convert_kwargs_to_cmd_line_args(kwargs):
def _get_input_args(input_node):
if input_node._name == input.__name__:
kwargs = copy.copy(input_node._kwargs)
if input_node.name == input.__name__:
kwargs = copy.copy(input_node.kwargs)
filename = kwargs.pop('filename')
fmt = kwargs.pop('format', None)
video_size = kwargs.pop('video_size', None)
@ -51,27 +52,6 @@ def _get_input_args(input_node):
return args
def _topo_sort(start_node):
marked_nodes = []
sorted_nodes = []
child_map = {}
def visit(node, child):
assert node not in marked_nodes, 'Graph is not a DAG'
if child is not None:
if node not in child_map:
child_map[node] = []
child_map[node].append(child)
if node not in sorted_nodes:
marked_nodes.append(node)
[visit(parent, node) for parent in node._parents]
marked_nodes.remove(node)
sorted_nodes.append(node)
unmarked_nodes = [start_node]
while unmarked_nodes:
visit(unmarked_nodes.pop(), None)
return sorted_nodes, child_map
def _get_filter_spec(i, node, stream_name_map):
stream_name = _get_stream_name('v{}'.format(i))
stream_name_map[node] = stream_name
@ -86,7 +66,7 @@ def _get_filter_arg(filter_nodes, stream_name_map):
def _get_global_args(node):
if node._name == overwrite_output.__name__:
if node.name == overwrite_output.__name__:
return ['-y']
else:
assert False, 'Unsupported global node: {}'.format(node)
@ -94,12 +74,12 @@ def _get_global_args(node):
def _get_output_args(node, stream_name_map):
args = []
if node._name != merge_outputs.__name__:
if node.name != merge_outputs.__name__:
stream_name = stream_name_map[node._parents[0]]
if stream_name != '[0]':
args += ['-map', stream_name]
if node._name == output.__name__:
kwargs = copy.copy(node._kwargs)
if node.name == output.__name__:
kwargs = copy.copy(node.kwargs)
filename = kwargs.pop('filename')
fmt = kwargs.pop('format', None)
if fmt:
@ -116,7 +96,7 @@ def get_args(node):
"""Get command-line arguments for ffmpeg."""
args = []
# TODO: group nodes together, e.g. `-i somefile -r somerate`.
sorted_nodes, child_map = _topo_sort(node)
sorted_nodes, child_map = topo_sort([node])
del(node)
input_nodes = [node for node in sorted_nodes if isinstance(node, InputNode)]
output_nodes = [node for node in sorted_nodes if isinstance(node, OutputNode) and not

178
ffmpeg/dag.py Normal file
View File

@ -0,0 +1,178 @@
from builtins import object
from collections import namedtuple
import copy
import hashlib
def _recursive_repr(item):
"""Hack around python `repr` to deterministically represent dictionaries.
This is able to represent more things than json.dumps, since it does not require things to be JSON serializable
(e.g. datetimes).
"""
if isinstance(item, basestring):
result = str(item)
elif isinstance(item, list):
result = '[{}]'.format(', '.join([_recursive_repr(x) for x in item]))
elif isinstance(item, dict):
kv_pairs = ['{}: {}'.format(_recursive_repr(k), _recursive_repr(item[k])) for k in sorted(item)]
result = '{' + ', '.join(kv_pairs) + '}'
else:
result = repr(item)
return result
def _get_hash(item):
hasher = hashlib.sha224()
repr_ = _recursive_repr(item)
hasher.update(repr_.encode('utf-8'))
return hasher.hexdigest()
class DagNode(object):
"""Node in a directed-acyclic graph (DAG).
Edges:
DagNodes are connected by edges. An edge connects two nodes with a label for each side:
- ``upstream_node``: upstream/parent node
- ``upstream_label``: label on the outgoing side of the upstream node
- ``downstream_node``: downstream/child node
- ``downstream_label``: label on the incoming side of the downstream node
For example, DagNode A may be connected to DagNode B with an edge labelled "foo" on A's side, and "bar" on B's
side:
_____ _____
| | | |
| A >[foo]---[bar]> B |
|_____| |_____|
Edge labels may be integers or strings, and nodes cannot have more than one incoming edge with the same label.
DagNodes may have any number of incoming edges and any number of outgoing edges. DagNodes keep track only of
their incoming edges, but the entire graph structure can be inferred by looking at the furthest downstream
nodes and working backwards.
Hashing:
DagNodes must be hashable, and two nodes are considered to be equivalent if they have the same hash value.
Nodes are immutable, and the hash should remain constant as a result. If a node with new contents is required,
create a new node and throw the old one away.
String representation:
In order for graph visualization tools to show useful information, nodes must be representable as strings. The
``repr`` operator should provide a more or less "full" representation of the node, and the ``short_repr``
property should be a shortened, concise representation.
Again, because nodes are immutable, the string representations should remain constant.
"""
def __hash__(self):
"""Return an integer hash of the node."""
raise NotImplementedError()
def __eq__(self, other):
"""Compare two nodes; implementations should return True if (and only if) hashes match."""
raise NotImplementedError()
def __repr__(self, other):
"""Return a full string representation of the node."""
raise NotImplementedError()
@property
def short_repr(self):
"""Return a partial/concise representation of the node."""
raise NotImplementedError()
@property
def incoming_edge_map(self):
"""Provides information about all incoming edges that connect to this node.
The edge map is a dictionary that maps an ``incoming_label`` to ``(outgoing_node, outgoing_label)``. Note that
implicity, ``incoming_node`` is ``self``. See "Edges" section above.
"""
raise NotImplementedError()
DagEdge = namedtuple('DagEdge', ['downstream_node', 'downstream_label', 'upstream_node', 'upstream_label'])
class KwargReprNode(DagNode):
"""A DagNode that can be represented as a set of args+kwargs.
"""
def __get_hash(self):
hashes = self.__upstream_hashes + [self.__inner_hash]
hash_strs = [str(x) for x in hashes]
hashes_str = ','.join(hash_strs).encode('utf-8')
hash_str = hashlib.md5(hashes_str).hexdigest()
return int(hash_str, base=16)
def __init__(self, incoming_edge_map, name, args, kwargs):
self.__incoming_edge_map = incoming_edge_map
self.name = name
self.args = args
self.kwargs = kwargs
self.__hash = self.__get_hash()
@property
def __upstream_hashes(self):
hashes = []
for downstream_label, (upstream_node, upstream_label) in self.incoming_edge_map.items():
hashes += [hash(x) for x in [downstream_label, upstream_node, upstream_label]]
return hashes
@property
def __inner_hash(self):
props = {'args': self.args, 'kwargs': self.kwargs}
return _get_hash(props)
def __hash__(self):
return self.__hash
def __eq__(self, other):
return hash(self) == hash(other)
@property
def short_hash(self):
return '{:x}'.format(abs(hash(self)))[:12]
def __repr__(self):
formatted_props = ['{!r}'.format(arg) for arg in self.args]
formatted_props += ['{}={!r}'.format(key, self.kwargs[key]) for key in sorted(self.kwargs)]
return '{}({}) <{}>'.format(self.name, ', '.join(formatted_props), self.short_hash)
@property
def incoming_edges(self):
edges = []
for downstream_label, (upstream_node, upstream_label) in self.incoming_edge_map.items():
downstream_node = self
edges += [DagEdge(downstream_node, downstream_label, upstream_node, upstream_label)]
return edges
@property
def incoming_edge_map(self):
return self.__incoming_edge_map
@property
def short_repr(self):
return self.name
def topo_sort(start_nodes):
marked_nodes = []
sorted_nodes = []
child_map = {}
def visit(node, child):
assert node not in marked_nodes, 'Graph is not a DAG'
if child is not None:
if node not in child_map:
child_map[node] = []
child_map[node].append(child)
if node not in sorted_nodes:
marked_nodes.append(node)
[visit(parent, node) for parent in node._parents]
marked_nodes.remove(node)
sorted_nodes.append(node)
unmarked_nodes = list(copy.copy(start_nodes))
while unmarked_nodes:
visit(unmarked_nodes.pop(), None)
return sorted_nodes, child_map

View File

@ -1,105 +1,22 @@
from __future__ import unicode_literals
from builtins import object
import copy
import hashlib
from .dag import KwargReprNode
def _recursive_repr(item):
"""Hack around python `repr` to deterministically represent dictionaries.
This is able to represent more things than json.dumps, since it does not require things to be JSON serializable
(e.g. datetimes).
"""
if isinstance(item, basestring):
result = str(item)
elif isinstance(item, list):
result = '[{}]'.format(', '.join([_recursive_repr(x) for x in item]))
elif isinstance(item, dict):
kv_pairs = ['{}: {}'.format(_recursive_repr(k), _recursive_repr(item[k])) for k in sorted(item)]
result = '{' + ', '.join(kv_pairs) + '}'
else:
result = repr(item)
return result
def _create_hash(item):
hasher = hashlib.sha224()
repr_ = _recursive_repr(item)
hasher.update(repr_.encode('utf-8'))
return hasher.hexdigest()
class _NodeBase(object):
@property
def hash(self):
if self._hash is None:
self._update_hash()
return self._hash
def __init__(self, parents, name):
parent_hashes = [hash(parent) for parent in parents]
assert len(parent_hashes) == len(set(parent_hashes)), 'Same node cannot be included as parent multiple times'
self._parents = parents
self._hash = None
self._name = name
def _transplant(self, new_parents):
other = copy.copy(self)
other._parents = copy.copy(new_parents)
return other
@property
def _repr_args(self):
raise NotImplementedError()
@property
def _repr_kwargs(self):
raise NotImplementedError()
@property
def _short_hash(self):
return '{:x}'.format(abs(hash(self)))[:12]
def __repr__(self):
args = self._repr_args
kwargs = self._repr_kwargs
formatted_props = ['{!r}'.format(arg) for arg in args]
formatted_props += ['{}={!r}'.format(key, kwargs[key]) for key in sorted(kwargs)]
return '{}({}) <{}>'.format(self._name, ', '.join(formatted_props), self._short_hash)
def __hash__(self):
if self._hash is None:
self._update_hash()
return self._hash
def __eq__(self, other):
return hash(self) == hash(other)
def _update_hash(self):
props = {'args': self._repr_args, 'kwargs': self._repr_kwargs}
my_hash = _create_hash(props)
parent_hashes = [str(hash(parent)) for parent in self._parents]
hashes = parent_hashes + [my_hash]
hashes_str = ','.join(hashes).encode('utf-8')
hash_str = hashlib.md5(hashes_str).hexdigest()
self._hash = int(hash_str, base=16)
class Node(_NodeBase):
class Node(KwargReprNode):
"""Node base"""
def __init__(self, parents, name, *args, **kwargs):
super(Node, self).__init__(parents, name)
self._args = args
self._kwargs = kwargs
incoming_edge_map = {}
for downstream_label, parent in enumerate(parents):
upstream_label = 0 # assume nodes have a single output (FIXME)
upstream_node = parent
incoming_edge_map[downstream_label] = (upstream_node, upstream_label)
super(Node, self).__init__(incoming_edge_map, name, args, kwargs)
@property
def _repr_args(self):
return self._args
@property
def _repr_kwargs(self):
return self._kwargs
def _parents(self):
# TODO: change graph compilation to use `self.incoming_edges` instead.
return [edge.upstream_node for edge in self.incoming_edges]
class InputNode(Node):
@ -111,9 +28,9 @@ class InputNode(Node):
class FilterNode(Node):
"""FilterNode"""
def _get_filter(self):
params_text = self._name
arg_params = ['{}'.format(arg) for arg in self._args]
kwarg_params = ['{}={}'.format(k, self._kwargs[k]) for k in sorted(self._kwargs)]
params_text = self.name
arg_params = ['{}'.format(arg) for arg in self.args]
kwarg_params = ['{}={}'.format(k, self.kwargs[k]) for k in sorted(self.kwargs)]
params = arg_params + kwarg_params
if params:
params_text += '={}'.format(':'.join(params))

View File

@ -73,12 +73,12 @@ def test_repr():
trim3 = ffmpeg.trim(in_file, start_frame=50, end_frame=60)
concatted = ffmpeg.concat(trim1, trim2, trim3)
output = ffmpeg.output(concatted, 'dummy2.mp4')
assert repr(in_file) == "input(filename={!r}) <{}>".format('dummy.mp4', in_file._short_hash)
assert repr(trim1) == "trim(end_frame=20, start_frame=10) <{}>".format(trim1._short_hash)
assert repr(trim2) == "trim(end_frame=40, start_frame=30) <{}>".format(trim2._short_hash)
assert repr(trim3) == "trim(end_frame=60, start_frame=50) <{}>".format(trim3._short_hash)
assert repr(concatted) == "concat(n=3) <{}>".format(concatted._short_hash)
assert repr(output) == "output(filename={!r}) <{}>".format('dummy2.mp4', output._short_hash)
assert repr(in_file) == "input(filename={!r}) <{}>".format('dummy.mp4', in_file.short_hash)
assert repr(trim1) == "trim(end_frame=20, start_frame=10) <{}>".format(trim1.short_hash)
assert repr(trim2) == "trim(end_frame=40, start_frame=30) <{}>".format(trim2.short_hash)
assert repr(trim3) == "trim(end_frame=60, start_frame=50) <{}>".format(trim3.short_hash)
assert repr(concatted) == "concat(n=3) <{}>".format(concatted.short_hash)
assert repr(output) == "output(filename={!r}) <{}>".format('dummy2.mp4', output.short_hash)
def test_get_args_simple():