Handle CRLF line endings in fix-encoding-pragma
diff --git a/pre_commit_hooks/fix_encoding_pragma.py b/pre_commit_hooks/fix_encoding_pragma.py
index bde4e78..23fc79f 100644
--- a/pre_commit_hooks/fix_encoding_pragma.py
+++ b/pre_commit_hooks/fix_encoding_pragma.py
@@ -9,14 +9,14 @@
from typing import Sequence
from typing import Union
-DEFAULT_PRAGMA = b'# -*- coding: utf-8 -*-\n'
+DEFAULT_PRAGMA = b'# -*- coding: utf-8 -*-'
def has_coding(line): # type: (bytes) -> bool
if not line.strip():
return False
return (
- line.lstrip()[0:1] == b'#' and (
+ line.lstrip()[:1] == b'#' and (
b'unicode' in line or
b'encoding' in line or
b'coding:' in line or
@@ -26,7 +26,7 @@
class ExpectedContents(collections.namedtuple(
- 'ExpectedContents', ('shebang', 'rest', 'pragma_status'),
+ 'ExpectedContents', ('shebang', 'rest', 'pragma_status', 'ending'),
)):
"""
pragma_status:
@@ -47,6 +47,8 @@
def _get_expected_contents(first_line, second_line, rest, expected_pragma):
# type: (bytes, bytes, bytes, bytes) -> ExpectedContents
+ ending = b'\r\n' if first_line.endswith(b'\r\n') else b'\n'
+
if first_line.startswith(b'#!'):
shebang = first_line
potential_coding = second_line
@@ -55,7 +57,7 @@
potential_coding = first_line
rest = second_line + rest
- if potential_coding == expected_pragma:
+ if potential_coding.rstrip(b'\r\n') == expected_pragma:
pragma_status = True # type: Optional[bool]
elif has_coding(potential_coding):
pragma_status = None
@@ -64,7 +66,7 @@
rest = potential_coding + rest
return ExpectedContents(
- shebang=shebang, rest=rest, pragma_status=pragma_status,
+ shebang=shebang, rest=rest, pragma_status=pragma_status, ending=ending,
)
@@ -93,7 +95,7 @@
f.truncate()
f.write(expected.shebang)
if not remove:
- f.write(expected_pragma)
+ f.write(expected_pragma + expected.ending)
f.write(expected.rest)
return 1
@@ -102,11 +104,7 @@
def _normalize_pragma(pragma): # type: (Union[bytes, str]) -> bytes
if not isinstance(pragma, bytes):
pragma = pragma.encode('UTF-8')
- return pragma.rstrip() + b'\n'
-
-
-def _to_disp(pragma): # type: (bytes) -> str
- return pragma.decode().rstrip()
+ return pragma.rstrip()
def main(argv=None): # type: (Optional[Sequence[str]]) -> int
@@ -117,7 +115,7 @@
parser.add_argument(
'--pragma', default=DEFAULT_PRAGMA, type=_normalize_pragma,
help='The encoding pragma to use. Default: {}'.format(
- _to_disp(DEFAULT_PRAGMA),
+ DEFAULT_PRAGMA.decode(),
),
)
parser.add_argument(
@@ -141,7 +139,7 @@
retv |= file_ret
if file_ret:
print(fmt.format(
- pragma=_to_disp(args.pragma), filename=filename,
+ pragma=args.pragma.decode(), filename=filename,
))
return retv