#1: refactor implementation to be simpler and more extensible

This commit is contained in:
Karl Kroening 2017-05-25 19:13:17 -10:00
parent dddf62869d
commit 4ada40b14c
2 changed files with 205 additions and 343 deletions

View File

@ -1,372 +1,234 @@
#!./venv/bin/python #!./venv/bin/python
from functools import partial
import hashlib import hashlib
import json import json
import operator import operator as _operator
import subprocess import subprocess
import sys
def _create_root_node(node_class, *args, **kwargs): class Node(object):
root = node_class(*args, **kwargs) def __init__(self, parents, name, *args, **kwargs):
root._update_hash() parent_hashes = [parent._hash for parent in parents]
return root
def _create_child_node(node_class, parent, *args, **kwargs):
child = node_class(parent, *args, **kwargs)
child._update_hash()
return child
class _Node(object):
def __init__(self, parents):
parent_hashes = [parent.hash for parent in parents]
assert len(parent_hashes) == len(set(parent_hashes)), 'Same node cannot be included as parent multiple times' assert len(parent_hashes) == len(set(parent_hashes)), 'Same node cannot be included as parent multiple times'
self.parents = parents self._parents = parents
self._name = name
self._args = args
self._kwargs = kwargs
self._update_hash()
@classmethod @classmethod
def _add_operator(cls, node_class): def _add_operator(cls, func):
if not getattr(node_class, 'STATIC', False): setattr(cls, func.__name__, func)
def func(self, *args, **kwargs):
return _create_child_node(node_class, self, *args, **kwargs)
setattr(cls, node_class.NAME, func)
@classmethod
def _add_operators(cls, node_classes):
[cls._add_operator(node_class) for node_class in node_classes]
@property
def _props(self):
return {k: v for k, v in self.__dict__.items() if k not in ['parents', 'hash']}
def __repr__(self): def __repr__(self):
# TODO: exclude default values. formatted_props = ['{}'.format(arg) for arg in self._args]
props = self._props formatted_props += ['{}={!r}'.format(key, self._kwargs[key]) for key in sorted(self._kwargs)]
formatted_props = ['{}={!r}'.format(key, props[key]) for key in sorted(self._props)] return '{}({})'.format(self._name, ','.join(formatted_props))
return '{}({})'.format(self.NAME, ','.join(formatted_props))
def __eq__(self, other): def __eq__(self, other):
return self.hash == other.hash return self._hash == other._hash
def _update_hash(self): def _update_hash(self):
my_hash = hashlib.md5(json.dumps(self._props)).hexdigest() props = {'args': self._args, 'kwargs': self._kwargs}
parent_hashes = [parent.hash for parent in self.parents] my_hash = hashlib.md5(json.dumps(props, sort_keys=True)).hexdigest()
parent_hashes = [parent._hash for parent in self._parents]
hashes = parent_hashes + [my_hash] hashes = parent_hashes + [my_hash]
self.hash = hashlib.md5(','.join(hashes)).hexdigest() self._hash = hashlib.md5(','.join(hashes)).hexdigest()
class _InputNode(_Node): class InputNode(Node):
def __init__(self, name, *args, **kwargs):
    """Create an input node, i.e. a node that has no parents.

    ``name`` identifies the kind of input; ``args`` and ``kwargs`` are
    stored on the node unchanged by ``Node.__init__``.
    """
    # ``parents`` and ``name`` are passed positionally on purpose: when they
    # are given as keywords together with ``*args``, the first extra
    # positional argument is bound to ``parents`` and the call fails with
    # "got multiple values for keyword argument 'parents'".
    super(InputNode, self).__init__([], name, *args, **kwargs)
class FilterNode(Node):
    """Graph node that is rendered as one filter in the filter spec."""

    def _get_filter(self):
        """Return the filter text for this node.

        The result is the node name alone when the node has no parameters,
        otherwise ``name=p1:p2:...`` where positional args come first (in
        order) followed by ``key=value`` pairs sorted by key.
        """
        rendered = ['{}'.format(value) for value in self._args]
        for key in sorted(self._kwargs):
            rendered.append('{}={}'.format(key, self._kwargs[key]))
        if not rendered:
            return self._name
        return self._name + '={}'.format(':'.join(rendered))
class OutputNode(Node):
pass pass
class _FileInputNode(_InputNode): class GlobalNode(Node):
NAME = 'file_input' def __init__(self, parent, name, *args, **kwargs):
STATIC = True assert isinstance(parent, OutputNode), 'Global nodes can only be attached after output nodes'
super(GlobalNode, self).__init__([parent], name, *args, **kwargs)
def __init__(self, filename):
super(_FileInputNode, self).__init__(parents=[])
self.filename = filename
def operator(node_classes={Node}):
    """Decorator factory that registers a function as a node method.

    The decorated function is handed to ``_add_operator`` of every class in
    ``node_classes`` and is then returned unchanged, so it stays usable as a
    plain module-level function as well.
    """
    def decorator(func):
        # Plain loop: registration is a side effect, no list is needed.
        for node_class in node_classes:
            node_class._add_operator(func)
        return func

    return decorator
class _FilterNode(_Node):
def _get_filter(self):
raise NotImplementedError()
def _get_params_from_dict(self, d): def file_input(filename):
params = "" return InputNode(file_input.__name__, filename=filename)
for k in self.kwargs:
params += k + "={}:".format(self.kwargs[k])
if len(params) > 0:
params = params[:-1]
return params
def _get_params_from_list(self, l):
return ":".join(["{}".format(i) for i in l])
def _get_filter_from_dict(self, d): @operator()
p = self._get_params_from_dict(d) def setpts(parent, expr):
if len(p) > 0: return FilterNode([parent], setpts.__name__, expr)
return self.NAME + "=" + p
return self.NAME
@operator()
def trim(parent, **kwargs):
    """Create a ``trim`` filter node whose single parent is ``parent``.

    All keyword arguments are stored on the node unchanged.
    """
    filter_name = trim.__name__
    return FilterNode([parent], filter_name, **kwargs)
class _TrimNode(_FilterNode):
NAME = 'trim'
def __init__(self, parent, **kwargs): @operator()
super(_TrimNode, self).__init__([parent]) def overlay(main_parent, overlay_parent, eof_action='repeat', **kwargs):
# if "setpts" not in kwargs: kwargs['eof_action'] = eof_action
# kwargs["setpts"] = 'PTS-STARTPTS' return FilterNode([main_parent, overlay_parent], overlay.__name__, **kwargs)
self.kwargs = kwargs
def _get_filter(self):
params = ""
for k in self.kwargs:
if k == "setpts":
continue
params += k
params += "={}:".format(self.kwargs[k])
if len(params) > 0:
params = params[:-1]
if "setpts" in self.kwargs: @operator()
params += "setpts={}".format(self.kwargs["setpts"]) def hflip(parent):
return FilterNode([parent], hflip.__name__)
return self.NAME + '=' + params
@operator()
def vflip(parent):
    """Create a parameterless ``vflip`` filter node below ``parent``."""
    filter_name = vflip.__name__
    return FilterNode([parent], filter_name)
class _OverlayNode(_FilterNode):
NAME = 'overlay'
def __init__(self, main_parent, overlay_parent, **kwargs): @operator()
super(_OverlayNode, self).__init__([main_parent, overlay_parent]) def drawbox(parent, x, y, width, height, color, thickness=None, **kwargs):
self.eof_action = eof_action if thickness:
self.kwargs = kwargs kwargs['t'] = thickness
return FilterNode([parent], drawbox.__name__, x, y, width, height, color, **kwargs)
def _get_filter(self):
return self._get_filter_from_dict(self.kwargs)
@operator()
def concat(*parents, **kwargs):
    """Create a ``concat`` filter node that has every given node as a parent.

    The ``n`` keyword is always set to the number of parents, replacing any
    ``n`` supplied by the caller.
    """
    kwargs.update(n=len(parents))
    filter_name = concat.__name__
    return FilterNode(parents, filter_name, **kwargs)
class _HFlipNode(_FilterNode):
NAME = 'hflip'
def __init__(self, parent): @operator()
super(_HFlipNode, self).__init__([parent]) def zoompan(parent, **kwargs):
return FilterNode([parent], zoompan.__name__, **kwargs)
def _get_filter(self):
return self.NAME
@operator()
def hue(parent, **kwargs):
    """Create a ``hue`` filter node whose single parent is ``parent``.

    All keyword arguments are stored on the node unchanged.
    """
    filter_name = hue.__name__
    return FilterNode([parent], filter_name, **kwargs)
class _VFlipNode(_FilterNode):
NAME = 'vflip'
def __init__(self, parent): @operator()
super(_VFlipNode, self).__init__([parent]) def colorchannelmixer(parent, *args, **kwargs):
return FilterNode([parent], colorchannelmixer.__name__, **kwargs)
def _get_filter(self):
return self.NAME
@operator(node_classes={OutputNode, GlobalNode})
def overwrite_output(parent):
    """Attach an ``overwrite_output`` global node after ``parent``."""
    node_name = overwrite_output.__name__
    return GlobalNode(parent, node_name)
class _DrawBoxNode(_FilterNode): @operator(node_classes={OutputNode})
NAME = 'drawbox' def merge_outputs(*parents):
return OutputNode(parents, merge_outputs.__name__)
def __init__(self, parent, x, y, width, height, color, **kwargs):
super(_DrawBoxNode, self).__init__([parent]) @operator(node_classes={InputNode, FilterNode})
self.x = x def file_output(parent, filename):
self.y = y return OutputNode([parent], file_output.__name__, filename=filename)
self.width = width
self.height = height
self.color = color def _get_stream_name(name):
self.kwargs = kwargs return '[{}]'.format(name)
def _get_filter(self):
f = 'drawbox={}:{}:{}:{}:{}'.format(self.x, self.y, self.width, self.height, self.color) def _get_input_args(input_node):
p = self._get_params_from_dict(self.kwargs) if input_node._name == file_input.__name__:
if len(p) > 0: args = ['-i', input_node._kwargs['filename']]
return f + ":" + p
return f
class _ConcatNode(_Node):
NAME = 'concat'
STATIC = True
def __init__(self, *parents):
super(_ConcatNode, self).__init__(parents)
def _get_filter(self):
return 'concat=n={}'.format(len(self.parents))
class _ZoomPanNode(_FilterNode):
NAME = 'zoompan'
def __init__(self, parent, **kwargs):
super(_ZoomPanNode, self).__init__([parent])
self.kwargs = kwargs
def _get_filter(self):
return self._get_filter_from_dict(self.kwargs)
class _HueNode(_FilterNode):
NAME = 'hue'
def __init__(self, parent, **kwargs):
super(_HueNode, self).__init__([parent])
self.kwargs = kwargs
def _get_filter(self):
return self._get_filter_from_dict(self.kwargs)
class _ColorChannelMixerNode(_FilterNode):
NAME = 'colorchannelmixer'
def __init__(self, parent, *args, **kwargs):
super(_ColorChannelMixerNode, self).__init__([parent])
self.args = args
self.kwargs = kwargs
def _get_filter(self):
f = self.NAME + "="
if self.args:
f += self._get_params_from_list(self.args)
if self.kwargs:
f += ":"
if self.kwargs:
f += self._get_params_from_dict(self.kwargs)
return f
class _OutputNode(_Node):
@classmethod
def _get_stream_name(cls, name):
return '[{}]'.format(name)
@classmethod
def _get_input_args(cls, input_node):
if isinstance(input_node, _FileInputNode):
args = ['-i', input_node.filename]
else:
assert False, 'Unsupported input node: {}'.format(input_node)
return args
@classmethod
def _topo_sort(cls, start_node):
marked_nodes = []
sorted_nodes = []
child_map = {}
def visit(node, child):
assert node not in marked_nodes, 'Graph is not a DAG'
if child is not None:
if node not in child_map:
child_map[node] = []
child_map[node].append(child)
if node not in sorted_nodes:
marked_nodes.append(node)
[visit(parent, node) for parent in node.parents]
marked_nodes.remove(node)
sorted_nodes.append(node)
unmarked_nodes = [start_node]
while unmarked_nodes:
visit(unmarked_nodes.pop(), None)
return sorted_nodes, child_map
@classmethod
def _get_filter_spec(cls, i, node, stream_name_map):
stream_name = cls._get_stream_name('v{}'.format(i))
stream_name_map[node] = stream_name
inputs = [stream_name_map[parent] for parent in node.parents]
filter_spec = '{}{}{}'.format(''.join(inputs), node._get_filter(), stream_name)
return filter_spec
@classmethod
def _get_filter_arg(cls, filter_nodes, stream_name_map):
filter_specs = [cls._get_filter_spec(i, node, stream_name_map) for i, node in enumerate(filter_nodes)]
return ';'.join(filter_specs)
@classmethod
def _get_global_args(cls, node):
if isinstance(node, _OverwriteOutputNode):
return ['-y']
else:
assert False, 'Unsupported global node: {}'.format(node)
@classmethod
def _get_output_args(cls, node, stream_name_map):
args = []
if not isinstance(node, _MergeOutputsNode):
stream_name = stream_name_map[node.parents[0]]
if stream_name != '[0]':
args += ['-map', stream_name]
if isinstance(node, _FileOutputNode):
args += [node.filename]
else:
assert False, 'Unsupported output node: {}'.format(node)
return args
def get_args(self):
args = []
# TODO: group nodes together, e.g. `-i somefile -r somerate`.
sorted_nodes, child_map = self._topo_sort(self)
input_nodes = [node for node in sorted_nodes if isinstance(node, _InputNode)]
output_nodes = [node for node in sorted_nodes if isinstance(node, _OutputNode) and not
isinstance(node, _GlobalNode)]
global_nodes = [node for node in sorted_nodes if isinstance(node, _GlobalNode)]
filter_nodes = [node for node in sorted_nodes if node not in (input_nodes + output_nodes + global_nodes)]
stream_name_map = {node: self._get_stream_name(i) for i, node in enumerate(input_nodes)}
filter_arg = self._get_filter_arg(filter_nodes, stream_name_map)
args += reduce(operator.add, [self._get_input_args(node) for node in input_nodes])
if filter_arg:
args += ['-filter_complex', filter_arg]
args += reduce(operator.add, [self._get_output_args(node, stream_name_map) for node in output_nodes])
args += reduce(operator.add, [self._get_global_args(node) for node in global_nodes], [])
return args
def run(self, cmd='ffmpeg'):
if type(cmd) == str:
cmd = [cmd]
args = cmd + self.get_args()
subprocess.check_call(args)
class _GlobalNode(_OutputNode):
def __init__(self, parent):
assert isinstance(parent, _OutputNode), 'Global nodes can only be attached after output nodes'
super(_GlobalNode, self).__init__([parent])
class _OverwriteOutputNode(_GlobalNode):
NAME = 'overwrite_output'
class _MergeOutputsNode(_OutputNode):
NAME = 'merge_outputs'
STATIC = True
def __init__(self, *parents):
assert not any([not isinstance(parent, _OutputNode) for parent in parents]), 'Can only merge output streams'
super(_MergeOutputsNode, self).__init__(*parents)
class _FileOutputNode(_OutputNode):
NAME = 'file_output'
def __init__(self, parent, filename):
super(_FileOutputNode, self).__init__([parent])
self.filename = filename
NODE_CLASSES = [
_HFlipNode,
_DrawBoxNode,
_ConcatNode,
_FileInputNode,
_FileOutputNode,
_OverlayNode,
_OverwriteOutputNode,
_TrimNode,
]
_Node._add_operators(NODE_CLASSES)
_module = sys.modules[__name__]
for _node_class in NODE_CLASSES:
if getattr(_node_class, 'STATIC', False):
func = _create_root_node
else: else:
func = _create_child_node assert False, 'Unsupported input node: {}'.format(input_node)
func = partial(func, _node_class) return args
setattr(_module, _node_class.NAME, func)
def get_args(node): def _topo_sort(start_node):
assert isinstance(node, _OutputNode), 'Cannot generate ffmpeg args for non-output node' marked_nodes = []
return node.get_args() sorted_nodes = []
child_map = {}
def visit(node, child):
assert node not in marked_nodes, 'Graph is not a DAG'
if child is not None:
if node not in child_map:
child_map[node] = []
child_map[node].append(child)
if node not in sorted_nodes:
marked_nodes.append(node)
[visit(parent, node) for parent in node._parents]
marked_nodes.remove(node)
sorted_nodes.append(node)
unmarked_nodes = [start_node]
while unmarked_nodes:
visit(unmarked_nodes.pop(), None)
return sorted_nodes, child_map
def run(node): def _get_filter_spec(i, node, stream_name_map):
assert isinstance(node, _OutputNode), 'Cannot run ffmpeg on non-output node' stream_name = _get_stream_name('v{}'.format(i))
return node.run() stream_name_map[node] = stream_name
inputs = [stream_name_map[parent] for parent in node._parents]
filter_spec = '{}{}{}'.format(''.join(inputs), node._get_filter(), stream_name)
return filter_spec
def _get_filter_arg(filter_nodes, stream_name_map):
    """Return the filter specs of ``filter_nodes`` joined with ``;``.

    Each node is rendered by ``_get_filter_spec`` together with its position
    in ``filter_nodes``; ``stream_name_map`` is passed through to that helper.
    An empty ``filter_nodes`` yields an empty string.
    """
    specs = []
    for index, filter_node in enumerate(filter_nodes):
        specs.append(_get_filter_spec(index, filter_node, stream_name_map))
    return ';'.join(specs)
def _get_global_args(node):
    """Return the command-line arguments contributed by a global node.

    Only ``overwrite_output`` nodes are supported; they yield ``['-y']``.

    Raises:
        AssertionError: if ``node`` is any other kind of node.
    """
    if node._name == overwrite_output.__name__:
        return ['-y']
    # Raise explicitly instead of ``assert False``: assert statements are
    # removed under ``python -O``, which would make this function fall
    # through and return None for an unsupported node.
    raise AssertionError('Unsupported global node: {}'.format(node))
def _get_output_args(node, stream_name_map):
    """Return the command-line arguments contributed by an output node.

    A ``merge_outputs`` node contributes nothing. Any other node is looked
    up through its first parent in ``stream_name_map``; a ``-map`` option is
    emitted unless that stream name is ``'[0]'``, and a ``file_output`` node
    then appends its ``filename``.

    Raises:
        AssertionError: if ``node`` is neither a ``merge_outputs`` nor a
            ``file_output`` node.
    """
    args = []
    if node._name == merge_outputs.__name__:
        return args

    stream_name = stream_name_map[node._parents[0]]
    if stream_name != '[0]':
        args += ['-map', stream_name]

    if node._name != file_output.__name__:
        # Raise explicitly instead of ``assert False``: assert statements are
        # removed under ``python -O``, which would let an unsupported node
        # pass through silently.
        raise AssertionError('Unsupported output node: {}'.format(node))
    args += [node._kwargs['filename']]
    return args
@operator(node_classes={OutputNode, GlobalNode})
def get_args(parent):
    """Build the ffmpeg argument list for the graph that ends at ``parent``.

    The nodes reachable from ``parent`` are sorted by ``_topo_sort`` and
    grouped by class. Arguments are emitted in this order: input arguments,
    an optional ``-filter_complex`` argument, output arguments, global
    arguments. The executable itself is not part of the result.

    Returns:
        A list of command-line arguments.
    """
    # TODO: group nodes together, e.g. `-i somefile -r somerate`.
    sorted_nodes, _ = _topo_sort(parent)
    input_nodes = [node for node in sorted_nodes if isinstance(node, InputNode)]
    output_nodes = [node for node in sorted_nodes if isinstance(node, OutputNode) and not
                    isinstance(node, GlobalNode)]
    global_nodes = [node for node in sorted_nodes if isinstance(node, GlobalNode)]
    # Built once instead of on every iteration of the comprehension below.
    non_filter_nodes = input_nodes + output_nodes + global_nodes
    filter_nodes = [node for node in sorted_nodes if node not in non_filter_nodes]

    stream_name_map = {node: _get_stream_name(i) for i, node in enumerate(input_nodes)}
    # Must run before the output arguments are built: rendering the filters
    # also records each filter node's stream name in ``stream_name_map``.
    filter_arg = _get_filter_arg(filter_nodes, stream_name_map)

    # Plain loops instead of ``reduce`` without an initializer, which raises
    # TypeError when a group is empty.
    args = []
    for node in input_nodes:
        args += _get_input_args(node)
    if filter_arg:
        args += ['-filter_complex', filter_arg]
    for node in output_nodes:
        args += _get_output_args(node, stream_name_map)
    for node in global_nodes:
        args += _get_global_args(node)
    return args
@operator(node_classes={OutputNode, GlobalNode})
def run(parent, cmd='ffmpeg'):
    """Run ffmpeg on the graph that ends at ``parent``.

    Args:
        parent: node whose ``get_args()`` supplies the arguments.
        cmd: the executable as a single string, or a list/tuple holding the
            executable followed by leading arguments.

    Raises:
        subprocess.CalledProcessError: if the process exits with a non-zero
            status.
    """
    # A list/tuple is used as the start of the command line; anything else
    # is treated as the executable itself. Wrapping a list in another list
    # would hand subprocess a nested list.
    if isinstance(cmd, (list, tuple)):
        cmd = list(cmd)
    else:
        cmd = [cmd]
    args = cmd + parent.get_args()
    subprocess.check_call(args)

View File

@ -18,11 +18,11 @@ def test_fluent_equality():
base1 = ffmpeg.file_input('dummy1.mp4') base1 = ffmpeg.file_input('dummy1.mp4')
base2 = ffmpeg.file_input('dummy1.mp4') base2 = ffmpeg.file_input('dummy1.mp4')
base3 = ffmpeg.file_input('dummy2.mp4') base3 = ffmpeg.file_input('dummy2.mp4')
t1 = base1.trim(10, 20) t1 = base1.trim(start_frame=10, end_frame=20)
t2 = base1.trim(10, 20) t2 = base1.trim(start_frame=10, end_frame=20)
t3 = base1.trim(10, 30) t3 = base1.trim(start_frame=10, end_frame=30)
t4 = base2.trim(10, 20) t4 = base2.trim(start_frame=10, end_frame=20)
t5 = base3.trim(10, 20) t5 = base3.trim(start_frame=10, end_frame=20)
assert t1 == t2 assert t1 == t2
assert t1 != t3 assert t1 != t3
assert t1 == t4 assert t1 == t4
@ -31,9 +31,9 @@ def test_fluent_equality():
def test_fluent_concat(): def test_fluent_concat():
base = ffmpeg.file_input('dummy.mp4') base = ffmpeg.file_input('dummy.mp4')
trimmed1 = base.trim(10, 20) trimmed1 = base.trim(start_frame=10, end_frame=20)
trimmed2 = base.trim(30, 40) trimmed2 = base.trim(start_frame=30, end_frame=40)
trimmed3 = base.trim(50, 60) trimmed3 = base.trim(start_frame=50, end_frame=60)
concat1 = ffmpeg.concat(trimmed1, trimmed2, trimmed3) concat1 = ffmpeg.concat(trimmed1, trimmed2, trimmed3)
concat2 = ffmpeg.concat(trimmed1, trimmed2, trimmed3) concat2 = ffmpeg.concat(trimmed1, trimmed2, trimmed3)
concat3 = ffmpeg.concat(trimmed1, trimmed3, trimmed2) concat3 = ffmpeg.concat(trimmed1, trimmed3, trimmed2)
@ -47,7 +47,7 @@ def test_fluent_concat():
def test_fluent_output(): def test_fluent_output():
ffmpeg \ ffmpeg \
.file_input('dummy.mp4') \ .file_input('dummy.mp4') \
.trim(10, 20) \ .trim(start_frame=10, end_frame=20) \
.file_output('dummy2.mp4') .file_output('dummy2.mp4')
@ -55,25 +55,25 @@ def test_fluent_complex_filter():
in_file = ffmpeg.file_input('dummy.mp4') in_file = ffmpeg.file_input('dummy.mp4')
return ffmpeg \ return ffmpeg \
.concat( .concat(
in_file.trim(10, 20), in_file.trim(start_frame=10, end_frame=20),
in_file.trim(30, 40), in_file.trim(start_frame=30, end_frame=40),
in_file.trim(50, 60) in_file.trim(start_frame=50, end_frame=60)
) \ ) \
.file_output('dummy2.mp4') .file_output('dummy2.mp4')
def test_repr(): def test_repr():
in_file = ffmpeg.file_input('dummy.mp4') in_file = ffmpeg.file_input('dummy.mp4')
trim1 = ffmpeg.trim(in_file, 10, 20) trim1 = ffmpeg.trim(in_file, start_frame=10, end_frame=20)
trim2 = ffmpeg.trim(in_file, 30, 40) trim2 = ffmpeg.trim(in_file, start_frame=30, end_frame=40)
trim3 = ffmpeg.trim(in_file, 50, 60) trim3 = ffmpeg.trim(in_file, start_frame=50, end_frame=60)
concatted = ffmpeg.concat(trim1, trim2, trim3) concatted = ffmpeg.concat(trim1, trim2, trim3)
output = ffmpeg.file_output(concatted, 'dummy2.mp4') output = ffmpeg.file_output(concatted, 'dummy2.mp4')
assert repr(in_file) == "file_input(filename='dummy.mp4')" assert repr(in_file) == "file_input(filename='dummy.mp4')"
assert repr(trim1) == "trim(end_frame=20,setpts='PTS-STARTPTS',start_frame=10)" assert repr(trim1) == "trim(end_frame=20,start_frame=10)"
assert repr(trim2) == "trim(end_frame=40,setpts='PTS-STARTPTS',start_frame=30)" assert repr(trim2) == "trim(end_frame=40,start_frame=30)"
assert repr(trim3) == "trim(end_frame=60,setpts='PTS-STARTPTS',start_frame=50)" assert repr(trim3) == "trim(end_frame=60,start_frame=50)"
assert repr(concatted) == "concat()" assert repr(concatted) == "concat(n=3)"
assert repr(output) == "file_output(filename='dummy2.mp4')" assert repr(output) == "file_output(filename='dummy2.mp4')"
@ -87,8 +87,8 @@ def _get_complex_filter_example():
overlay_file = ffmpeg.file_input(TEST_OVERLAY_FILE) overlay_file = ffmpeg.file_input(TEST_OVERLAY_FILE)
return ffmpeg \ return ffmpeg \
.concat( .concat(
in_file.trim(10, 20), in_file.trim(start_frame=10, end_frame=20),
in_file.trim(30, 40), in_file.trim(start_frame=30, end_frame=40),
) \ ) \
.overlay(overlay_file.hflip()) \ .overlay(overlay_file.hflip()) \
.drawbox(50, 50, 120, 120, color='red', thickness=5) \ .drawbox(50, 50, 120, 120, color='red', thickness=5) \
@ -103,8 +103,8 @@ def test_get_args_complex_filter():
'-i', TEST_INPUT_FILE, '-i', TEST_INPUT_FILE,
'-i', TEST_OVERLAY_FILE, '-i', TEST_OVERLAY_FILE,
'-filter_complex', '-filter_complex',
'[0]trim=start_frame=10:end_frame=20,setpts=PTS-STARTPTS[v0];' \ '[0]trim=end_frame=20:start_frame=10[v0];' \
'[0]trim=start_frame=30:end_frame=40,setpts=PTS-STARTPTS[v1];' \ '[0]trim=end_frame=40:start_frame=30[v1];' \
'[v0][v1]concat=n=2[v2];' \ '[v0][v1]concat=n=2[v2];' \
'[1]hflip[v3];' \ '[1]hflip[v3];' \
'[v2][v3]overlay=eof_action=repeat[v4];' \ '[v2][v3]overlay=eof_action=repeat[v4];' \