From 4ada40b14cd702609a51761ca31b04a982d350d8 Mon Sep 17 00:00:00 2001 From: Karl Kroening Date: Thu, 25 May 2017 19:13:17 -1000 Subject: [PATCH] #1: refactor implementation to be simpler and more extensible --- ffmpeg/__init__.py | 502 +++++++++++++----------------------- ffmpeg/tests/test_ffmpeg.py | 46 ++-- 2 files changed, 205 insertions(+), 343 deletions(-) diff --git a/ffmpeg/__init__.py b/ffmpeg/__init__.py index 486987d..241c828 100755 --- a/ffmpeg/__init__.py +++ b/ffmpeg/__init__.py @@ -1,372 +1,234 @@ #!./venv/bin/python -from functools import partial import hashlib import json -import operator +import operator as _operator import subprocess -import sys -def _create_root_node(node_class, *args, **kwargs): - root = node_class(*args, **kwargs) - root._update_hash() - return root - - -def _create_child_node(node_class, parent, *args, **kwargs): - child = node_class(parent, *args, **kwargs) - child._update_hash() - return child - - -class _Node(object): - def __init__(self, parents): - parent_hashes = [parent.hash for parent in parents] +class Node(object): + def __init__(self, parents, name, *args, **kwargs): + parent_hashes = [parent._hash for parent in parents] assert len(parent_hashes) == len(set(parent_hashes)), 'Same node cannot be included as parent multiple times' - self.parents = parents + self._parents = parents + self._name = name + self._args = args + self._kwargs = kwargs + self._update_hash() @classmethod - def _add_operator(cls, node_class): - if not getattr(node_class, 'STATIC', False): - def func(self, *args, **kwargs): - return _create_child_node(node_class, self, *args, **kwargs) - setattr(cls, node_class.NAME, func) - - @classmethod - def _add_operators(cls, node_classes): - [cls._add_operator(node_class) for node_class in node_classes] - - @property - def _props(self): - return {k: v for k, v in self.__dict__.items() if k not in ['parents', 'hash']} + def _add_operator(cls, func): + setattr(cls, func.__name__, func) def __repr__(self): - # TODO: exclude default values. - props = self._props - formatted_props = ['{}={!r}'.format(key, props[key]) for key in sorted(self._props)] - return '{}({})'.format(self.NAME, ','.join(formatted_props)) + formatted_props = ['{}'.format(arg) for arg in self._args] + formatted_props += ['{}={!r}'.format(key, self._kwargs[key]) for key in sorted(self._kwargs)] + return '{}({})'.format(self._name, ','.join(formatted_props)) def __eq__(self, other): - return self.hash == other.hash + return self._hash == other._hash def _update_hash(self): - my_hash = hashlib.md5(json.dumps(self._props)).hexdigest() - parent_hashes = [parent.hash for parent in self.parents] + props = {'args': self._args, 'kwargs': self._kwargs} + my_hash = hashlib.md5(json.dumps(props, sort_keys=True)).hexdigest() + parent_hashes = [parent._hash for parent in self._parents] hashes = parent_hashes + [my_hash] - self.hash = hashlib.md5(','.join(hashes)).hexdigest() + self._hash = hashlib.md5(','.join(hashes)).hexdigest() -class _InputNode(_Node): +class InputNode(Node): + def __init__(self, name, *args, **kwargs): + super(InputNode, self).__init__(parents=[], name=name, *args, **kwargs) + + +class FilterNode(Node): + def _get_filter(self): + params_text = self._name + arg_params = ['{}'.format(arg) for arg in self._args] + kwarg_params = ['{}={}'.format(k, self._kwargs[k]) for k in sorted(self._kwargs)] + params = arg_params + kwarg_params + if params: + params_text += '={}'.format(':'.join(params)) + return params_text + + +class OutputNode(Node): pass -class _FileInputNode(_InputNode): - NAME = 'file_input' - STATIC = True +class GlobalNode(Node): + def __init__(self, parent, name, *args, **kwargs): + assert isinstance(parent, OutputNode), 'Global nodes can only be attached after output nodes' + super(GlobalNode, self).__init__([parent], name, *args, **kwargs) - def __init__(self, filename): - super(_FileInputNode, self).__init__(parents=[]) - self.filename = filename +def operator(node_classes={Node}): + def decorator(func): + [node_class._add_operator(func) for node_class in node_classes] + return func + return decorator -class _FilterNode(_Node): - def _get_filter(self): - raise NotImplementedError() - def _get_params_from_dict(self, d): - params = "" - for k in self.kwargs: - params += k + "={}:".format(self.kwargs[k]) - if len(params) > 0: - params = params[:-1] - return params +def file_input(filename): + return InputNode(file_input.__name__, filename=filename) - def _get_params_from_list(self, l): - return ":".join(["{}".format(i) for i in l]) - def _get_filter_from_dict(self, d): - p = self._get_params_from_dict(d) - if len(p) > 0: - return self.NAME + "=" + p - return self.NAME +@operator() +def setpts(parent, expr): + return FilterNode([parent], setpts.__name__, expr) +@operator() +def trim(parent, **kwargs): + return FilterNode([parent], trim.__name__, **kwargs) -class _TrimNode(_FilterNode): - NAME = 'trim' - def __init__(self, parent, **kwargs): - super(_TrimNode, self).__init__([parent]) - # if "setpts" not in kwargs: - # kwargs["setpts"] = 'PTS-STARTPTS' - self.kwargs = kwargs +@operator() +def overlay(main_parent, overlay_parent, eof_action='repeat', **kwargs): + kwargs['eof_action'] = eof_action + return FilterNode([main_parent, overlay_parent], overlay.__name__, **kwargs) - def _get_filter(self): - params = "" - for k in self.kwargs: - if k == "setpts": - continue - params += k - params += "={}:".format(self.kwargs[k]) - if len(params) > 0: - params = params[:-1] - if "setpts" in self.kwargs: - params += "setpts={}".format(self.kwargs["setpts"]) +@operator() +def hflip(parent): + return FilterNode([parent], hflip.__name__) - return self.NAME + '=' + params +@operator() +def vflip(parent): + return FilterNode([parent], vflip.__name__) -class _OverlayNode(_FilterNode): - NAME = 'overlay' - def __init__(self, main_parent, overlay_parent, **kwargs): - super(_OverlayNode, self).__init__([main_parent, overlay_parent]) - self.eof_action = eof_action - self.kwargs = kwargs +@operator() +def drawbox(parent, x, y, width, height, color, thickness=None, **kwargs): + if thickness: + kwargs['t'] = thickness + return FilterNode([parent], drawbox.__name__, x, y, width, height, color, **kwargs) - def _get_filter(self): - return self._get_filter_from_dict(self.kwargs) +@operator() +def concat(*parents, **kwargs): + kwargs['n'] = len(parents) + return FilterNode(parents, concat.__name__, **kwargs) -class _HFlipNode(_FilterNode): - NAME = 'hflip' - def __init__(self, parent): - super(_HFlipNode, self).__init__([parent]) +@operator() +def zoompan(parent, **kwargs): + return FilterNode([parent], zoompan.__name__, **kwargs) - def _get_filter(self): - return self.NAME +@operator() +def hue(parent, **kwargs): + return FilterNode([parent], hue.__name__, **kwargs) -class _VFlipNode(_FilterNode): - NAME = 'vflip' - def __init__(self, parent): - super(_VFlipNode, self).__init__([parent]) +@operator() +def colorchannelmixer(parent, *args, **kwargs): + return FilterNode([parent], colorchannelmixer.__name__, **kwargs) - def _get_filter(self): - return self.NAME - - - -class _DrawBoxNode(_FilterNode): - NAME = 'drawbox' - - def __init__(self, parent, x, y, width, height, color, **kwargs): - super(_DrawBoxNode, self).__init__([parent]) - self.x = x - self.y = y - self.width = width - self.height = height - self.color = color - self.kwargs = kwargs - - def _get_filter(self): - f = 'drawbox={}:{}:{}:{}:{}'.format(self.x, self.y, self.width, self.height, self.color) - p = self._get_params_from_dict(self.kwargs) - if len(p) > 0: - return f + ":" + p - return f - - -class _ConcatNode(_Node): - NAME = 'concat' - STATIC = True - - def __init__(self, *parents): - super(_ConcatNode, self).__init__(parents) - - def _get_filter(self): - return 'concat=n={}'.format(len(self.parents)) - -class _ZoomPanNode(_FilterNode): - NAME = 'zoompan' - - def __init__(self, parent, **kwargs): - super(_ZoomPanNode, self).__init__([parent]) - self.kwargs = kwargs - - def _get_filter(self): - return self._get_filter_from_dict(self.kwargs) - -class _HueNode(_FilterNode): - NAME = 'hue' - - def __init__(self, parent, **kwargs): - super(_HueNode, self).__init__([parent]) - self.kwargs = kwargs - - def _get_filter(self): - return self._get_filter_from_dict(self.kwargs) - -class _ColorChannelMixerNode(_FilterNode): - NAME = 'colorchannelmixer' - - def __init__(self, parent, *args, **kwargs): - super(_ColorChannelMixerNode, self).__init__([parent]) - self.args = args - self.kwargs = kwargs - - def _get_filter(self): - f = self.NAME + "=" - if self.args: - f += self._get_params_from_list(self.args) - if self.kwargs: - f += ":" - if self.kwargs: - f += self._get_params_from_dict(self.kwargs) - return f - - -class _OutputNode(_Node): - @classmethod - def _get_stream_name(cls, name): - return '[{}]'.format(name) - - @classmethod - def _get_input_args(cls, input_node): - if isinstance(input_node, _FileInputNode): - args = ['-i', input_node.filename] - else: - assert False, 'Unsupported input node: {}'.format(input_node) - return args - - @classmethod - def _topo_sort(cls, start_node): - marked_nodes = [] - sorted_nodes = [] - child_map = {} - def visit(node, child): - assert node not in marked_nodes, 'Graph is not a DAG' - if child is not None: - if node not in child_map: - child_map[node] = [] - child_map[node].append(child) - if node not in sorted_nodes: - marked_nodes.append(node) - [visit(parent, node) for parent in node.parents] - marked_nodes.remove(node) - sorted_nodes.append(node) - unmarked_nodes = [start_node] - while unmarked_nodes: - visit(unmarked_nodes.pop(), None) - return sorted_nodes, child_map - - @classmethod - def _get_filter_spec(cls, i, node, stream_name_map): - stream_name = cls._get_stream_name('v{}'.format(i)) - stream_name_map[node] = stream_name - inputs = [stream_name_map[parent] for parent in node.parents] - filter_spec = '{}{}{}'.format(''.join(inputs), node._get_filter(), stream_name) - return filter_spec - - @classmethod - def _get_filter_arg(cls, filter_nodes, stream_name_map): - filter_specs = [cls._get_filter_spec(i, node, stream_name_map) for i, node in enumerate(filter_nodes)] - return ';'.join(filter_specs) - - @classmethod - def _get_global_args(cls, node): - if isinstance(node, _OverwriteOutputNode): - return ['-y'] - else: - assert False, 'Unsupported global node: {}'.format(node) - - @classmethod - def _get_output_args(cls, node, stream_name_map): - args = [] - if not isinstance(node, _MergeOutputsNode): - stream_name = stream_name_map[node.parents[0]] - if stream_name != '[0]': - args += ['-map', stream_name] - if isinstance(node, _FileOutputNode): - args += [node.filename] - else: - assert False, 'Unsupported output node: {}'.format(node) - return args - - def get_args(self): - args = [] - # TODO: group nodes together, e.g. `-i somefile -r somerate`. - sorted_nodes, child_map = self._topo_sort(self) - input_nodes = [node for node in sorted_nodes if isinstance(node, _InputNode)] - output_nodes = [node for node in sorted_nodes if isinstance(node, _OutputNode) and not - isinstance(node, _GlobalNode)] - global_nodes = [node for node in sorted_nodes if isinstance(node, _GlobalNode)] - filter_nodes = [node for node in sorted_nodes if node not in (input_nodes + output_nodes + global_nodes)] - stream_name_map = {node: self._get_stream_name(i) for i, node in enumerate(input_nodes)} - filter_arg = self._get_filter_arg(filter_nodes, stream_name_map) - args += reduce(operator.add, [self._get_input_args(node) for node in input_nodes]) - if filter_arg: - args += ['-filter_complex', filter_arg] - args += reduce(operator.add, [self._get_output_args(node, stream_name_map) for node in output_nodes]) - args += reduce(operator.add, [self._get_global_args(node) for node in global_nodes], []) - return args - - def run(self, cmd='ffmpeg'): - if type(cmd) == str: - cmd = [cmd] - args = cmd + self.get_args() - subprocess.check_call(args) - - -class _GlobalNode(_OutputNode): - def __init__(self, parent): - assert isinstance(parent, _OutputNode), 'Global nodes can only be attached after output nodes' - super(_GlobalNode, self).__init__([parent]) - - -class _OverwriteOutputNode(_GlobalNode): - NAME = 'overwrite_output' - - -class _MergeOutputsNode(_OutputNode): - NAME = 'merge_outputs' - STATIC = True - - def __init__(self, *parents): - assert not any([not isinstance(parent, _OutputNode) for parent in parents]), 'Can only merge output streams' - super(_MergeOutputsNode, self).__init__(*parents) - - -class _FileOutputNode(_OutputNode): - NAME = 'file_output' - - def __init__(self, parent, filename): - super(_FileOutputNode, self).__init__([parent]) - self.filename = filename - - -NODE_CLASSES = [ - _HFlipNode, - _DrawBoxNode, - _ConcatNode, - _FileInputNode, - _FileOutputNode, - _OverlayNode, - _OverwriteOutputNode, - _TrimNode, -] - -_Node._add_operators(NODE_CLASSES) - -_module = sys.modules[__name__] -for _node_class in NODE_CLASSES: - if getattr(_node_class, 'STATIC', False): - func = _create_root_node + +@operator(node_classes={OutputNode, GlobalNode}) +def overwrite_output(parent): + return GlobalNode(parent, overwrite_output.__name__) + + +@operator(node_classes={OutputNode}) +def merge_outputs(*parents): + return OutputNode(parents, merge_outputs.__name__) + + +@operator(node_classes={InputNode, FilterNode}) +def file_output(parent, filename): + return OutputNode([parent], file_output.__name__, filename=filename) + + +def _get_stream_name(name): + return '[{}]'.format(name) + + +def _get_input_args(input_node): + if input_node._name == file_input.__name__: + args = ['-i', input_node._kwargs['filename']] else: - func = _create_child_node - func = partial(func, _node_class) - setattr(_module, _node_class.NAME, func) + assert False, 'Unsupported input node: {}'.format(input_node) + return args -def get_args(node): - assert isinstance(node, _OutputNode), 'Cannot generate ffmpeg args for non-output node' - return node.get_args() +def _topo_sort(start_node): + marked_nodes = [] + sorted_nodes = [] + child_map = {} + def visit(node, child): + assert node not in marked_nodes, 'Graph is not a DAG' + if child is not None: + if node not in child_map: + child_map[node] = [] + child_map[node].append(child) + if node not in sorted_nodes: + marked_nodes.append(node) + [visit(parent, node) for parent in node._parents] + marked_nodes.remove(node) + sorted_nodes.append(node) + unmarked_nodes = [start_node] + while unmarked_nodes: + visit(unmarked_nodes.pop(), None) + return sorted_nodes, child_map -def run(node): - assert isinstance(node, _OutputNode), 'Cannot run ffmpeg on non-output node' - return node.run() +def _get_filter_spec(i, node, stream_name_map): + stream_name = _get_stream_name('v{}'.format(i)) + stream_name_map[node] = stream_name + inputs = [stream_name_map[parent] for parent in node._parents] + filter_spec = '{}{}{}'.format(''.join(inputs), node._get_filter(), stream_name) + return filter_spec + + +def _get_filter_arg(filter_nodes, stream_name_map): + filter_specs = [_get_filter_spec(i, node, stream_name_map) for i, node in enumerate(filter_nodes)] + return ';'.join(filter_specs) + + +def _get_global_args(node): + if node._name == overwrite_output.__name__: + return ['-y'] + else: + assert False, 'Unsupported global node: {}'.format(node) + + +def _get_output_args(node, stream_name_map): + args = [] + if node._name != merge_outputs.__name__: + stream_name = stream_name_map[node._parents[0]] + if stream_name != '[0]': + args += ['-map', stream_name] + if node._name == file_output.__name__: + args += [node._kwargs['filename']] + else: + assert False, 'Unsupported output node: {}'.format(node) + return args + + +@operator(node_classes={OutputNode, GlobalNode}) +def get_args(parent): + args = [] + # TODO: group nodes together, e.g. `-i somefile -r somerate`. + sorted_nodes, child_map = _topo_sort(parent) + input_nodes = [node for node in sorted_nodes if isinstance(node, InputNode)] + output_nodes = [node for node in sorted_nodes if isinstance(node, OutputNode) and not + isinstance(node, GlobalNode)] + global_nodes = [node for node in sorted_nodes if isinstance(node, GlobalNode)] + filter_nodes = [node for node in sorted_nodes if node not in (input_nodes + output_nodes + global_nodes)] + stream_name_map = {node: _get_stream_name(i) for i, node in enumerate(input_nodes)} + filter_arg = _get_filter_arg(filter_nodes, stream_name_map) + args += reduce(_operator.add, [_get_input_args(node) for node in input_nodes]) + if filter_arg: + args += ['-filter_complex', filter_arg] + args += reduce(_operator.add, [_get_output_args(node, stream_name_map) for node in output_nodes]) + args += reduce(_operator.add, [_get_global_args(node) for node in global_nodes], []) + return args + + +@operator(node_classes={OutputNode, GlobalNode}) +def run(parent, cmd='ffmpeg'): + args = [cmd] + parent.get_args() + subprocess.check_call(args) diff --git a/ffmpeg/tests/test_ffmpeg.py b/ffmpeg/tests/test_ffmpeg.py index 825804d..da1531d 100644 --- a/ffmpeg/tests/test_ffmpeg.py +++ b/ffmpeg/tests/test_ffmpeg.py @@ -18,11 +18,11 @@ def test_fluent_equality(): base1 = ffmpeg.file_input('dummy1.mp4') base2 = ffmpeg.file_input('dummy1.mp4') base3 = ffmpeg.file_input('dummy2.mp4') - t1 = base1.trim(10, 20) - t2 = base1.trim(10, 20) - t3 = base1.trim(10, 30) - t4 = base2.trim(10, 20) - t5 = base3.trim(10, 20) + t1 = base1.trim(start_frame=10, end_frame=20) + t2 = base1.trim(start_frame=10, end_frame=20) + t3 = base1.trim(start_frame=10, end_frame=30) + t4 = base2.trim(start_frame=10, end_frame=20) + t5 = base3.trim(start_frame=10, end_frame=20) assert t1 == t2 assert t1 != t3 assert t1 == t4 @@ -31,9 +31,9 @@ def test_fluent_equality(): def test_fluent_concat(): base = ffmpeg.file_input('dummy.mp4') - trimmed1 = base.trim(10, 20) - trimmed2 = base.trim(30, 40) - trimmed3 = base.trim(50, 60) + trimmed1 = base.trim(start_frame=10, end_frame=20) + trimmed2 = base.trim(start_frame=30, end_frame=40) + trimmed3 = base.trim(start_frame=50, end_frame=60) concat1 = ffmpeg.concat(trimmed1, trimmed2, trimmed3) concat2 = ffmpeg.concat(trimmed1, trimmed2, trimmed3) concat3 = ffmpeg.concat(trimmed1, trimmed3, trimmed2) @@ -47,7 +47,7 @@ def test_fluent_concat(): def test_fluent_output(): ffmpeg \ .file_input('dummy.mp4') \ - .trim(10, 20) \ + .trim(start_frame=10, end_frame=20) \ .file_output('dummy2.mp4') @@ -55,25 +55,25 @@ def test_fluent_complex_filter(): in_file = ffmpeg.file_input('dummy.mp4') return ffmpeg \ .concat( - in_file.trim(10, 20), - in_file.trim(30, 40), - in_file.trim(50, 60) + in_file.trim(start_frame=10, end_frame=20), + in_file.trim(start_frame=30, end_frame=40), + in_file.trim(start_frame=50, end_frame=60) ) \ .file_output('dummy2.mp4') def test_repr(): in_file = ffmpeg.file_input('dummy.mp4') - trim1 = ffmpeg.trim(in_file, 10, 20) - trim2 = ffmpeg.trim(in_file, 30, 40) - trim3 = ffmpeg.trim(in_file, 50, 60) + trim1 = ffmpeg.trim(in_file, start_frame=10, end_frame=20) + trim2 = ffmpeg.trim(in_file, start_frame=30, end_frame=40) + trim3 = ffmpeg.trim(in_file, start_frame=50, end_frame=60) concatted = ffmpeg.concat(trim1, trim2, trim3) output = ffmpeg.file_output(concatted, 'dummy2.mp4') assert repr(in_file) == "file_input(filename='dummy.mp4')" - assert repr(trim1) == "trim(end_frame=20,setpts='PTS-STARTPTS',start_frame=10)" - assert repr(trim2) == "trim(end_frame=40,setpts='PTS-STARTPTS',start_frame=30)" - assert repr(trim3) == "trim(end_frame=60,setpts='PTS-STARTPTS',start_frame=50)" - assert repr(concatted) == "concat()" + assert repr(trim1) == "trim(end_frame=20,start_frame=10)" + assert repr(trim2) == "trim(end_frame=40,start_frame=30)" + assert repr(trim3) == "trim(end_frame=60,start_frame=50)" + assert repr(concatted) == "concat(n=3)" assert repr(output) == "file_output(filename='dummy2.mp4')" @@ -87,8 +87,8 @@ def _get_complex_filter_example(): overlay_file = ffmpeg.file_input(TEST_OVERLAY_FILE) return ffmpeg \ .concat( - in_file.trim(10, 20), - in_file.trim(30, 40), + in_file.trim(start_frame=10, end_frame=20), + in_file.trim(start_frame=30, end_frame=40), ) \ .overlay(overlay_file.hflip()) \ .drawbox(50, 50, 120, 120, color='red', thickness=5) \ @@ -103,8 +103,8 @@ def test_get_args_complex_filter(): '-i', TEST_INPUT_FILE, '-i', TEST_OVERLAY_FILE, '-filter_complex', - '[0]trim=start_frame=10:end_frame=20,setpts=PTS-STARTPTS[v0];' \ - '[0]trim=start_frame=30:end_frame=40,setpts=PTS-STARTPTS[v1];' \ + '[0]trim=end_frame=20:start_frame=10[v0];' \ + '[0]trim=end_frame=40:start_frame=30[v1];' \ '[v0][v1]concat=n=2[v2];' \ '[1]hflip[v3];' \ '[v2][v3]overlay=eof_action=repeat[v4];' \