Merge pull request #133 from pre-commit/customizable_encoding_pragma
Allow encoding pragma to be customizable
diff --git a/pre_commit_hooks/fix_encoding_pragma.py b/pre_commit_hooks/fix_encoding_pragma.py
index 8586937..5dcff93 100644
--- a/pre_commit_hooks/fix_encoding_pragma.py
+++ b/pre_commit_hooks/fix_encoding_pragma.py
@@ -5,7 +5,7 @@
import argparse
import collections
-expected_pragma = b'# -*- coding: utf-8 -*-\n'
+DEFAULT_PRAGMA = b'# -*- coding: utf-8 -*-\n'
def has_coding(line):
@@ -41,7 +41,7 @@
return self.pragma_status is expected_pragma_status
-def _get_expected_contents(first_line, second_line, rest):
+def _get_expected_contents(first_line, second_line, rest, expected_pragma):
if first_line.startswith(b'#!'):
shebang = first_line
potential_coding = second_line
@@ -63,8 +63,10 @@
)
-def fix_encoding_pragma(f, remove=False):
- expected = _get_expected_contents(f.readline(), f.readline(), f.read())
+def fix_encoding_pragma(f, remove=False, expected_pragma=DEFAULT_PRAGMA):
+ expected = _get_expected_contents(
+ f.readline(), f.readline(), f.read(), expected_pragma,
+ )
# Special cases for empty files
if not expected.rest.strip():
@@ -91,10 +93,26 @@
return 1
+def _normalize_pragma(pragma):
+    """Return *pragma* as bytes ending in exactly one newline.
+
+    Used as the argparse ``type`` for ``--pragma``.  Text input (how
+    Python 3 delivers command-line arguments) is encoded as UTF-8; bytes
+    input is used as-is.  All trailing whitespace, including any existing
+    newline, is replaced by a single newline.
+    """
+    if not isinstance(pragma, bytes):
+        pragma = pragma.encode('UTF-8')
+    return pragma.rstrip() + b'\n'
+
+
+def _to_disp(pragma):
+    """Return the bytes *pragma* as text for display, minus trailing whitespace.
+
+    Decodes explicitly as UTF-8 to mirror ``_normalize_pragma``, which
+    encodes text pragmas as UTF-8.  Relying on the default codec (ASCII on
+    Python 2) would raise ``UnicodeDecodeError`` for a non-ASCII pragma.
+    """
+    return pragma.decode('UTF-8').rstrip()
+
+
def main(argv=None):
parser = argparse.ArgumentParser('Fixes the encoding pragma of python files')
parser.add_argument('filenames', nargs='*', help='Filenames to fix')
parser.add_argument(
+ '--pragma', default=DEFAULT_PRAGMA, type=_normalize_pragma,
+ help='The encoding pragma to use. Default: {}'.format(
+ _to_disp(DEFAULT_PRAGMA),
+ ),
+ )
+ parser.add_argument(
'--remove', action='store_true',
help='Remove the encoding pragma (Useful in a python3-only codebase)',
)
@@ -109,10 +127,14 @@
for filename in args.filenames:
with open(filename, 'r+b') as f:
- file_ret = fix_encoding_pragma(f, remove=args.remove)
+ file_ret = fix_encoding_pragma(
+ f, remove=args.remove, expected_pragma=args.pragma,
+ )
retv |= file_ret
if file_ret:
- print(fmt.format(pragma=expected_pragma, filename=filename))
+ print(fmt.format(
+ pragma=_to_disp(args.pragma), filename=filename,
+ ))
return retv
diff --git a/tests/fix_encoding_pragma_test.py b/tests/fix_encoding_pragma_test.py
index a9502a2..d49f1ba 100644
--- a/tests/fix_encoding_pragma_test.py
+++ b/tests/fix_encoding_pragma_test.py
@@ -5,6 +5,7 @@
import pytest
+from pre_commit_hooks.fix_encoding_pragma import _normalize_pragma
from pre_commit_hooks.fix_encoding_pragma import fix_encoding_pragma
from pre_commit_hooks.fix_encoding_pragma import main
@@ -106,3 +107,46 @@
assert fix_encoding_pragma(bytesio) == 1
bytesio.seek(0)
assert bytesio.read() == output
+
+
+def test_ok_input_alternate_pragma():
+    # A file that already starts with the custom pragma is reported as OK
+    # (return 0) and its contents are left byte-for-byte unchanged.
+    input_s = b'# coding: utf-8\nx = 1\n'
+    bytesio = io.BytesIO(input_s)
+    ret = fix_encoding_pragma(bytesio, expected_pragma=b'# coding: utf-8\n')
+    assert ret == 0
+    bytesio.seek(0)
+    assert bytesio.read() == input_s
+
+
+def test_not_ok_input_alternate_pragma():
+    # A file with no pragma is fixed (return 1) by prepending the custom
+    # pragma rather than the default one.
+    bytesio = io.BytesIO(b'x = 1\n')
+    ret = fix_encoding_pragma(bytesio, expected_pragma=b'# coding: utf-8\n')
+    assert ret == 1
+    bytesio.seek(0)
+    assert bytesio.read() == b'# coding: utf-8\nx = 1\n'
+
+
+@pytest.mark.parametrize(
+    ('input_s', 'expected'),
+    (
+        # Python 2 cli parameters are bytes
+        (b'# coding: utf-8', b'# coding: utf-8\n'),
+        # Python 3 cli parameters are text
+        ('# coding: utf-8', b'# coding: utf-8\n'),
+        # trailing whitespace (here a newline) collapses to one newline
+        ('# coding: utf-8\n', b'# coding: utf-8\n'),
+    ),
+)
+def test_normalize_pragma(input_s, expected):
+    """Bytes or text input is normalized to bytes ending in one newline."""
+    assert _normalize_pragma(input_s) == expected
+
+
+def test_integration_alternate_pragma(tmpdir, capsys):
+    # End to end through main(): --pragma rewrites the file with the custom
+    # pragma, returns 1, and reports the pragma's text form on stdout.
+    f = tmpdir.join('f.py')
+    f.write('x = 1\n')
+
+    pragma = '# coding: utf-8'
+    assert main((f.strpath, '--pragma', pragma)) == 1
+    assert f.read() == '# coding: utf-8\nx = 1\n'
+    out, _ = capsys.readouterr()
+    assert out == 'Added `# coding: utf-8` to {}\n'.format(f.strpath)