Merge pull request #133 from pre-commit/customizable_encoding_pragma

Allow encoding pragma to be customizable
diff --git a/pre_commit_hooks/fix_encoding_pragma.py b/pre_commit_hooks/fix_encoding_pragma.py
index 8586937..5dcff93 100644
--- a/pre_commit_hooks/fix_encoding_pragma.py
+++ b/pre_commit_hooks/fix_encoding_pragma.py
@@ -5,7 +5,7 @@
 import argparse
 import collections
 
-expected_pragma = b'# -*- coding: utf-8 -*-\n'
+DEFAULT_PRAGMA = b'# -*- coding: utf-8 -*-\n'  # used when --pragma is omitted
 
 
 def has_coding(line):
@@ -41,7 +41,7 @@
         return self.pragma_status is expected_pragma_status
 
 
-def _get_expected_contents(first_line, second_line, rest):
+def _get_expected_contents(first_line, second_line, rest, expected_pragma):
     if first_line.startswith(b'#!'):
         shebang = first_line
         potential_coding = second_line
@@ -63,8 +63,10 @@
     )
 
 
-def fix_encoding_pragma(f, remove=False):
-    expected = _get_expected_contents(f.readline(), f.readline(), f.read())
+def fix_encoding_pragma(f, remove=False, expected_pragma=DEFAULT_PRAGMA):
+    expected = _get_expected_contents(
+        f.readline(), f.readline(), f.read(), expected_pragma,
+    )
 
     # Special cases for empty files
     if not expected.rest.strip():
@@ -91,10 +93,26 @@
     return 1
 
 
+def _normalize_pragma(pragma):
+    # py2 argv arrives as bytes, py3 argv as text; coerce text to UTF-8 bytes
+    raw = pragma if isinstance(pragma, bytes) else pragma.encode('UTF-8')
+    return raw.rstrip() + b'\n'
+
+
+def _to_disp(pragma):  # bytes pragma -> printable text without the newline
+    return pragma.decode('UTF-8').rstrip()  # same codec as _normalize_pragma
+
+
 def main(argv=None):
     parser = argparse.ArgumentParser('Fixes the encoding pragma of python files')
     parser.add_argument('filenames', nargs='*', help='Filenames to fix')
     parser.add_argument(
+        '--pragma', default=DEFAULT_PRAGMA, type=_normalize_pragma,
+        help='The encoding pragma to use.  Default: {}'.format(
+            _to_disp(DEFAULT_PRAGMA),
+        ),
+    )
+    parser.add_argument(
         '--remove', action='store_true',
         help='Remove the encoding pragma (Useful in a python3-only codebase)',
     )
@@ -109,10 +127,14 @@
 
     for filename in args.filenames:
         with open(filename, 'r+b') as f:
-            file_ret = fix_encoding_pragma(f, remove=args.remove)
+            file_ret = fix_encoding_pragma(
+                f, remove=args.remove, expected_pragma=args.pragma,
+            )
             retv |= file_ret
             if file_ret:
-                print(fmt.format(pragma=expected_pragma, filename=filename))
+                print(fmt.format(
+                    pragma=_to_disp(args.pragma), filename=filename,
+                ))
 
     return retv
 
diff --git a/tests/fix_encoding_pragma_test.py b/tests/fix_encoding_pragma_test.py
index a9502a2..d49f1ba 100644
--- a/tests/fix_encoding_pragma_test.py
+++ b/tests/fix_encoding_pragma_test.py
@@ -5,6 +5,7 @@
 
 import pytest
 
+from pre_commit_hooks.fix_encoding_pragma import _normalize_pragma
 from pre_commit_hooks.fix_encoding_pragma import fix_encoding_pragma
 from pre_commit_hooks.fix_encoding_pragma import main
 
@@ -106,3 +107,46 @@
     assert fix_encoding_pragma(bytesio) == 1
     bytesio.seek(0)
     assert bytesio.read() == output
+
+
+def test_ok_input_alternate_pragma():
+    # A file already carrying the requested pragma must be left untouched.
+    contents = b'# coding: utf-8\nx = 1\n'
+    buf = io.BytesIO(contents)
+    assert fix_encoding_pragma(buf, expected_pragma=b'# coding: utf-8\n') == 0
+    buf.seek(0)
+    assert buf.read() == contents
+
+
+def test_not_ok_input_alternate_pragma():
+    buf = io.BytesIO(b'x = 1\n')
+    # the requested pragma is prepended when it is missing
+    assert fix_encoding_pragma(buf, expected_pragma=b'# coding: utf-8\n') == 1
+    buf.seek(0)
+    assert buf.read() == b'# coding: utf-8\nx = 1\n'
+
+
+@pytest.mark.parametrize(
+    'input_s,expected',
+    [
+        # py2: command-line arguments arrive as bytes
+        (b'# coding: utf-8', b'# coding: utf-8\n'),
+        # py3: command-line arguments arrive as text
+        ('# coding: utf-8', b'# coding: utf-8\n'),
+        # an existing trailing newline is stripped, not doubled
+        ('# coding: utf-8\n', b'# coding: utf-8\n'),
+    ],
+)
+def test_normalize_pragma(input_s, expected):
+    assert expected == _normalize_pragma(input_s)
+
+
+def test_integration_alternate_pragma(tmpdir, capsys):
+    pragma = '# coding: utf-8'
+    py_file = tmpdir.join('f.py')
+    py_file.write('x = 1\n')
+    # --pragma overrides the default and is echoed without its newline
+    assert main((py_file.strpath, '--pragma', pragma)) == 1
+    assert py_file.read() == '# coding: utf-8\nx = 1\n'
+    stdout, _ = capsys.readouterr()
+    assert stdout == 'Added `# coding: utf-8` to {}\n'.format(py_file.strpath)