Handle CRLF line endings in fix-encoding-pragma
diff --git a/pre_commit_hooks/fix_encoding_pragma.py b/pre_commit_hooks/fix_encoding_pragma.py
index bde4e78..23fc79f 100644
--- a/pre_commit_hooks/fix_encoding_pragma.py
+++ b/pre_commit_hooks/fix_encoding_pragma.py
@@ -9,14 +9,14 @@
from typing import Sequence
from typing import Union
-DEFAULT_PRAGMA = b'# -*- coding: utf-8 -*-\n'
+DEFAULT_PRAGMA = b'# -*- coding: utf-8 -*-'
def has_coding(line): # type: (bytes) -> bool
if not line.strip():
return False
return (
- line.lstrip()[0:1] == b'#' and (
+ line.lstrip()[:1] == b'#' and (
b'unicode' in line or
b'encoding' in line or
b'coding:' in line or
@@ -26,7 +26,7 @@
class ExpectedContents(collections.namedtuple(
- 'ExpectedContents', ('shebang', 'rest', 'pragma_status'),
+ 'ExpectedContents', ('shebang', 'rest', 'pragma_status', 'ending'),
)):
"""
pragma_status:
@@ -47,6 +47,8 @@
def _get_expected_contents(first_line, second_line, rest, expected_pragma):
# type: (bytes, bytes, bytes, bytes) -> ExpectedContents
+ ending = b'\r\n' if first_line.endswith(b'\r\n') else b'\n'
+
if first_line.startswith(b'#!'):
shebang = first_line
potential_coding = second_line
@@ -55,7 +57,7 @@
potential_coding = first_line
rest = second_line + rest
- if potential_coding == expected_pragma:
+ if potential_coding.rstrip(b'\r\n') == expected_pragma:
pragma_status = True # type: Optional[bool]
elif has_coding(potential_coding):
pragma_status = None
@@ -64,7 +66,7 @@
rest = potential_coding + rest
return ExpectedContents(
- shebang=shebang, rest=rest, pragma_status=pragma_status,
+ shebang=shebang, rest=rest, pragma_status=pragma_status, ending=ending,
)
@@ -93,7 +95,7 @@
f.truncate()
f.write(expected.shebang)
if not remove:
- f.write(expected_pragma)
+ f.write(expected_pragma + expected.ending)
f.write(expected.rest)
return 1
@@ -102,11 +104,7 @@
def _normalize_pragma(pragma): # type: (Union[bytes, str]) -> bytes
if not isinstance(pragma, bytes):
pragma = pragma.encode('UTF-8')
- return pragma.rstrip() + b'\n'
-
-
-def _to_disp(pragma): # type: (bytes) -> str
- return pragma.decode().rstrip()
+ return pragma.rstrip()
def main(argv=None): # type: (Optional[Sequence[str]]) -> int
@@ -117,7 +115,7 @@
parser.add_argument(
'--pragma', default=DEFAULT_PRAGMA, type=_normalize_pragma,
help='The encoding pragma to use. Default: {}'.format(
- _to_disp(DEFAULT_PRAGMA),
+ DEFAULT_PRAGMA.decode(),
),
)
parser.add_argument(
@@ -141,7 +139,7 @@
retv |= file_ret
if file_ret:
print(fmt.format(
- pragma=_to_disp(args.pragma), filename=filename,
+ pragma=args.pragma.decode(), filename=filename,
))
return retv
diff --git a/tests/fix_encoding_pragma_test.py b/tests/fix_encoding_pragma_test.py
index 7288bfa..d94b725 100644
--- a/tests/fix_encoding_pragma_test.py
+++ b/tests/fix_encoding_pragma_test.py
@@ -112,7 +112,7 @@
def test_ok_input_alternate_pragma():
input_s = b'# coding: utf-8\nx = 1\n'
bytesio = io.BytesIO(input_s)
- ret = fix_encoding_pragma(bytesio, expected_pragma=b'# coding: utf-8\n')
+ ret = fix_encoding_pragma(bytesio, expected_pragma=b'# coding: utf-8')
assert ret == 0
bytesio.seek(0)
assert bytesio.read() == input_s
@@ -120,7 +120,7 @@
def test_not_ok_input_alternate_pragma():
bytesio = io.BytesIO(b'x = 1\n')
- ret = fix_encoding_pragma(bytesio, expected_pragma=b'# coding: utf-8\n')
+ ret = fix_encoding_pragma(bytesio, expected_pragma=b'# coding: utf-8')
assert ret == 1
bytesio.seek(0)
assert bytesio.read() == b'# coding: utf-8\nx = 1\n'
@@ -130,11 +130,11 @@
('input_s', 'expected'),
(
# Python 2 cli parameters are bytes
- (b'# coding: utf-8', b'# coding: utf-8\n'),
+ (b'# coding: utf-8', b'# coding: utf-8'),
# Python 3 cli parameters are text
- ('# coding: utf-8', b'# coding: utf-8\n'),
+ ('# coding: utf-8', b'# coding: utf-8'),
# trailing whitespace
- ('# coding: utf-8\n', b'# coding: utf-8\n'),
+ ('# coding: utf-8\n', b'# coding: utf-8'),
),
)
def test_normalize_pragma(input_s, expected):
@@ -150,3 +150,16 @@
assert f.read() == '# coding: utf-8\nx = 1\n'
out, _ = capsys.readouterr()
assert out == 'Added `# coding: utf-8` to {}\n'.format(f.strpath)
+
+
+def test_crlf_ok(tmpdir):
+ f = tmpdir.join('f.py')
+ f.write_binary(b'# -*- coding: utf-8 -*-\r\nx = 1\r\n')
+ assert not main((f.strpath,))
+
+
+def test_crfl_adds(tmpdir):
+ f = tmpdir.join('f.py')
+ f.write_binary(b'x = 1\r\n')
+ assert main((f.strpath,))
+ assert f.read_binary() == b'# -*- coding: utf-8 -*-\r\nx = 1\r\n'