commit ae89e7b095a4d6ad32cbff9f7183f8067154683a Author: Karl Kroening Date: Sat May 13 18:27:06 2017 -1000 Initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..56d875b --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +venv +.cache +tests/dummy2.mp4 diff --git a/ffmpeg.py b/ffmpeg.py new file mode 100755 index 0000000..5b9903b --- /dev/null +++ b/ffmpeg.py @@ -0,0 +1,252 @@ +#!./venv/bin/python + +from functools import partial +import hashlib +import json +import operator +import subprocess + + +def _create_root_node(node_class, *args, **kwargs): + root = node_class(*args, **kwargs) + root._update_hash() + return root + + +def _create_child_node(node_class, parent, *args, **kwargs): + child = node_class([parent], *args, **kwargs) + child._update_hash() + return child + + +class _Node(object): + def __init__(self, parents): + parent_hashes = [parent.hash for parent in parents] + assert len(parent_hashes) == len(set(parent_hashes)), 'Same node cannot be included as parent multiple times' + self.parents = parents + + @classmethod + def _add_operator(cls, node_class): + if getattr(node_class, 'STATIC', False): + @classmethod + def func(cls2, *args, **kwargs): + return _create_root_node(node_class, *args, **kwargs) + else: + def func(self, *args, **kwargs): + return _create_child_node(node_class, self, *args, **kwargs) + setattr(cls, node_class.NAME, func) + + @classmethod + def _add_operators(cls, node_classes): + [cls._add_operator(node_class) for node_class in node_classes] + + @property + def _props(self): + return {k: v for k, v in self.__dict__.items() if k not in ['parents', 'hash']} + + def __repr__(self): + # TODO: exclude default values. + props = self._props + formatted_props = ['{}={!r}'.format(key, props[key]) for key in sorted(self._props)] + return '{}({})'.format(self.NAME, ','.join(formatted_props)) + + def __eq__(self, other): + return self.hash == other.hash + + def _update_hash(self): + my_hash = hashlib.md5(json.dumps(self._props)).hexdigest() + parent_hashes = [parent.hash for parent in self.parents] + hashes = parent_hashes + [my_hash] + self.hash = hashlib.md5(','.join(hashes)).hexdigest() + + +class _InputNode(_Node): + pass + + +class _FileInputNode(_InputNode): + NAME = 'file_input' + STATIC = True + + def __init__(self, filename): + super(_FileInputNode, self).__init__(parents=[]) + self.filename = filename + + +class _FilterNode(_Node): + pass + + +class _TrimFilterNode(_FilterNode): + NAME = 'trim' + + def __init__(self, parents, start_frame, end_frame, setpts='PTS-STARTPTS'): + super(_TrimFilterNode, self).__init__(parents) + self.start_frame = start_frame + self.end_frame = end_frame + self.setpts = setpts + + +class _ConcatNode(_Node): + NAME = 'concat' + STATIC = True + + def __init__(self, *parents): + super(_ConcatNode, self).__init__(parents) + + +class _OutputNode(_Node): + @classmethod + def _get_stream_name(cls, name): + return '[{}]'.format(name) + + @classmethod + def _get_input_args(cls, input_node): + if isinstance(input_node, _FileInputNode): + args = ['-i', input_node.filename] + else: + assert False, 'Unsupported input node: {}'.format(input_node) + return args + + @classmethod + def _topo_sort(cls, start_node): + marked_nodes = [] + sorted_nodes = [] + child_map = {} + def visit(node, child): + assert node not in marked_nodes, 'Graph is not a DAG' + if child is not None: + if node not in child_map: + child_map[node] = [] + child_map[node].append(child) + if node not in sorted_nodes: + marked_nodes.append(node) + [visit(parent, node) for parent in node.parents] + marked_nodes.remove(node) + sorted_nodes.append(node) + unmarked_nodes = [start_node] + while unmarked_nodes: + visit(unmarked_nodes.pop(), None) + return sorted_nodes, child_map + + @classmethod + def _get_filter(cls, node): + # TODO: find a better way to do this instead of ugly if/elifs. + if isinstance(node, _TrimFilterNode): + return 'trim=start_frame={}:end_frame={},setpts={}'.format(node.start_frame, node.end_frame, node.setpts) + elif isinstance(node, _ConcatNode): + return 'concat=n={}'.format(len(node.parents)) + else: + assert False, 'Unsupported filter node: {}'.format(node) + + @classmethod + def _get_filter_spec(cls, i, node, stream_name_map): + stream_name = cls._get_stream_name('v{}'.format(i)) + stream_name_map[node] = stream_name + inputs = [stream_name_map[parent] for parent in node.parents] + filter_spec = '{}{}{}'.format(''.join(inputs), cls._get_filter(node), stream_name) + return filter_spec + + @classmethod + def _get_filter_arg(cls, filter_nodes, stream_name_map): + filter_specs = [cls._get_filter_spec(i, node, stream_name_map) for i, node in enumerate(filter_nodes)] + return ';'.join(filter_specs) + + @classmethod + def _get_global_args(cls, node): + if isinstance(node, _OverwriteOutputNode): + return ['-y'] + else: + assert False, 'Unsupported global node: {}'.format(node) + + @classmethod + def _get_output_args(cls, node, stream_name_map): + args = [] + if not isinstance(node, _MergeOutputsNode): + stream_name = stream_name_map[node.parents[0]] + if stream_name != '[0]': + args += ['-map', stream_name] + if isinstance(node, _FileOutputNode): + args += [node.filename] + else: + assert False, 'Unsupported output node: {}'.format(node) + return args + + def get_args(self): + args = [] + # TODO: group nodes together, e.g. `-i somefile -r somerate`. + sorted_nodes, child_map = self._topo_sort(self) + input_nodes = [node for node in sorted_nodes if isinstance(node, _InputNode)] + output_nodes = [node for node in sorted_nodes if isinstance(node, _OutputNode) and not + isinstance(node, _GlobalNode)] + global_nodes = [node for node in sorted_nodes if isinstance(node, _GlobalNode)] + filter_nodes = [node for node in sorted_nodes if node not in (input_nodes + output_nodes + global_nodes)] + stream_name_map = {node: self._get_stream_name(i) for i, node in enumerate(input_nodes)} + filter_arg = self._get_filter_arg(filter_nodes, stream_name_map) + args += reduce(operator.add, [self._get_input_args(node) for node in input_nodes]) + if filter_arg: + args += ['-filter_complex', filter_arg] + args += reduce(operator.add, [self._get_output_args(node, stream_name_map) for node in output_nodes]) + args += reduce(operator.add, [self._get_global_args(node) for node in global_nodes], []) + return args + + def run(self): + args = ['ffmpeg'] + self.get_args() + subprocess.check_call(args) + + +class _GlobalNode(_OutputNode): + def __init__(self, parents): + assert len(parents) == 1 + assert isinstance(parents[0], _OutputNode), 'Global nodes can only be attached after output nodes' + super(_GlobalNode, self).__init__(parents) + + +class _OverwriteOutputNode(_GlobalNode): + NAME = 'overwrite_output' + + + +class _MergeOutputsNode(_OutputNode): + NAME = 'merge_outputs' + + def __init__(self, *parents): + assert not any([not isinstance(parent, _OutputNode) for parent in parents]), 'Can only merge output streams' + super(_MergeOutputsNode, self).__init__(*parents) + + +class _FileOutputNode(_OutputNode): + NAME = 'file_output' + + def __init__(self, parents, filename): + super(_FileOutputNode, self).__init__(parents) + self.filename = filename + + +NODE_CLASSES = [ + _ConcatNode, + _FileInputNode, + _FileOutputNode, + _OverwriteOutputNode, + _TrimFilterNode, +] + +_Node._add_operators(NODE_CLASSES) + + +for node_class in NODE_CLASSES: + if getattr(node_class, 'STATIC', False): + func = _create_root_node + else: + func = _create_child_node + globals()[node_class.NAME] = partial(func, node_class) + + +def get_args(node): + assert isinstance(node, _OutputNode), 'Cannot generate ffmpeg args for non-output node' + return node.get_args() + + +def run(node): + assert isinstance(node, _OutputNode), 'Cannot run ffmpeg on non-output node' + return node.run() diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..2423407 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +testpaths = tests +#norecursedirs = venv .git diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..e079f8a --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +pytest diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/dummy.mp4 b/tests/dummy.mp4 new file mode 100644 index 0000000..2c7d59e Binary files /dev/null and b/tests/dummy.mp4 differ diff --git a/tests/test_ffmpeg.py b/tests/test_ffmpeg.py new file mode 100644 index 0000000..9a8ef95 --- /dev/null +++ b/tests/test_ffmpeg.py @@ -0,0 +1,104 @@ +import ffmpeg +import os + + +TEST_DIR = os.path.dirname(__file__) +TEST_INPUT_FILE = os.path.join(TEST_DIR, 'dummy.mp4') +TEST_OUTPUT_FILE = os.path.join(TEST_DIR, 'dummy2.mp4') + + +def test_fluent_equality(): + base1 = ffmpeg.file_input('dummy1.mp4') + base2 = ffmpeg.file_input('dummy1.mp4') + base3 = ffmpeg.file_input('dummy2.mp4') + t1 = base1.trim(10, 20) + t2 = base1.trim(10, 20) + t3 = base1.trim(10, 30) + t4 = base2.trim(10, 20) + t5 = base3.trim(10, 20) + assert t1 == t2 + assert t1 != t3 + assert t1 == t4 + assert t1 != t5 + + +def test_fluent_concat(): + base = ffmpeg.file_input('dummy.mp4') + trimmed1 = base.trim(10, 20) + trimmed2 = base.trim(30, 40) + trimmed3 = base.trim(50, 60) + concat1 = ffmpeg.concat(trimmed1, trimmed2, trimmed3) + concat2 = ffmpeg.concat(trimmed1, trimmed2, trimmed3) + concat3 = ffmpeg.concat(trimmed1, trimmed3, trimmed2) + concat4 = ffmpeg.concat() + concat5 = ffmpeg.concat() + assert concat1 == concat2 + assert concat1 != concat3 + assert concat4 == concat5 + + +def test_fluent_output(): + ffmpeg \ + .file_input('dummy.mp4') \ + .trim(10, 20) \ + .file_output('dummy2.mp4') + + +def test_fluent_complex_filter(): + in_file = ffmpeg.file_input('dummy.mp4') + return ffmpeg \ + .concat( + in_file.trim(10, 20), + in_file.trim(30, 40), + in_file.trim(50, 60) + ) \ + .file_output('dummy2.mp4') + + +def test_repr(): + in_file = ffmpeg.file_input('dummy.mp4') + trim1 = ffmpeg.trim(in_file, 10, 20) + trim2 = ffmpeg.trim(in_file, 30, 40) + trim3 = ffmpeg.trim(in_file, 50, 60) + concatted = ffmpeg.concat(trim1, trim2, trim3) + output = ffmpeg.file_output(concatted, 'dummy2.mp4') + assert repr(in_file) == "file_input(filename='dummy.mp4')" + assert repr(trim1) == "trim(end_frame=20,setpts='PTS-STARTPTS',start_frame=10)" + assert repr(trim2) == "trim(end_frame=40,setpts='PTS-STARTPTS',start_frame=30)" + assert repr(trim3) == "trim(end_frame=60,setpts='PTS-STARTPTS',start_frame=50)" + assert repr(concatted) == "concat()" + assert repr(output) == "file_output(filename='dummy2.mp4')" + + +def test_get_args_simple(): + out_file = ffmpeg.file_input('dummy.mp4').file_output('dummy2.mp4') + assert out_file.get_args() == ['-i', 'dummy.mp4', 'dummy2.mp4'] + + +def _get_complex_filter_example(): + in_file = ffmpeg.file_input(TEST_INPUT_FILE) + concatted = ffmpeg.concat( + ffmpeg.trim(in_file, 10, 20), + ffmpeg.trim(in_file, 30, 40), + ffmpeg.trim(in_file, 50, 60), + ) + out = ffmpeg.file_output(concatted, TEST_OUTPUT_FILE) + return ffmpeg.overwrite_output(out) + + +def test_get_args_complex_filter(): + out = _get_complex_filter_example() + assert ffmpeg.get_args(out) == [ + '-i', TEST_INPUT_FILE, + '-filter_complex', + '[0]trim=start_frame=10:end_frame=20,setpts=PTS-STARTPTS[v0];' \ + '[0]trim=start_frame=30:end_frame=40,setpts=PTS-STARTPTS[v1];' \ + '[0]trim=start_frame=50:end_frame=60,setpts=PTS-STARTPTS[v2];' \ + '[v0][v1][v2]concat=n=3[v3]', + '-map', '[v3]', TEST_OUTPUT_FILE, + '-y', + ] + + +def test_run(): + ffmpeg.run(_get_complex_filter_example())