Add a --remove option to fix-encoding-pragma
diff --git a/pre_commit_hooks/fix_encoding_pragma.py b/pre_commit_hooks/fix_encoding_pragma.py
index 48fc9c7..8586937 100644
--- a/pre_commit_hooks/fix_encoding_pragma.py
+++ b/pre_commit_hooks/fix_encoding_pragma.py
@@ -3,7 +3,7 @@
from __future__ import unicode_literals
import argparse
-import io
+import collections
expected_pragma = b'# -*- coding: utf-8 -*-\n'
@@ -21,34 +21,72 @@
)
-def fix_encoding_pragma(f):
- first_line = f.readline()
- second_line = f.readline()
- old = f.read()
- f.seek(0)
+class ExpectedContents(collections.namedtuple(
+ 'ExpectedContents', ('shebang', 'rest', 'pragma_status'),
+)):
+ """
+ pragma_status:
+ - True: has exactly the coding pragma expected
+ - False: missing coding pragma entirely
+ - None: has a coding pragma, but it does not match
+ """
+ __slots__ = ()
- # Ok case: the file is empty
- if not (first_line + second_line + old).strip():
- return 0
+ @property
+ def has_any_pragma(self):
+ return self.pragma_status is not False
- # Ok case: we specify pragma as the first line
- if first_line == expected_pragma:
- return 0
+ def is_expected_pragma(self, remove):
+ expected_pragma_status = not remove
+ return self.pragma_status is expected_pragma_status
- # OK case: we have a shebang as first line and pragma on second line
- if first_line.startswith(b'#!') and second_line == expected_pragma:
- return 0
- # Otherwise we need to rewrite stuff!
+def _get_expected_contents(first_line, second_line, rest):
if first_line.startswith(b'#!'):
- if has_coding(second_line):
- f.write(first_line + expected_pragma + old)
- else:
- f.write(first_line + expected_pragma + second_line + old)
- elif has_coding(first_line):
- f.write(expected_pragma + second_line + old)
+ shebang = first_line
+ potential_coding = second_line
else:
- f.write(expected_pragma + first_line + second_line + old)
+ shebang = b''
+ potential_coding = first_line
+ rest = second_line + rest
+
+ if potential_coding == expected_pragma:
+ pragma_status = True
+ elif has_coding(potential_coding):
+ pragma_status = None
+ else:
+ pragma_status = False
+ rest = potential_coding + rest
+
+ return ExpectedContents(
+ shebang=shebang, rest=rest, pragma_status=pragma_status,
+ )
+
+
+def fix_encoding_pragma(f, remove=False):
+ expected = _get_expected_contents(f.readline(), f.readline(), f.read())
+
+ # Special cases for empty files
+ if not expected.rest.strip():
+ # If a file only has a shebang or a coding pragma, remove it
+ if expected.has_any_pragma or expected.shebang:
+ f.seek(0)
+ f.truncate()
+ f.write(b'')
+ return 1
+ else:
+ return 0
+
+ if expected.is_expected_pragma(remove):
+ return 0
+
+ # Otherwise, write out the new file
+ f.seek(0)
+ f.truncate()
+ f.write(expected.shebang)
+ if not remove:
+ f.write(expected_pragma)
+ f.write(expected.rest)
return 1
@@ -56,18 +94,29 @@
def main(argv=None):
parser = argparse.ArgumentParser('Fixes the encoding pragma of python files')
parser.add_argument('filenames', nargs='*', help='Filenames to fix')
+ parser.add_argument(
+ '--remove', action='store_true',
+ help='Remove the encoding pragma (Useful in a python3-only codebase)',
+ )
args = parser.parse_args(argv)
retv = 0
+ if args.remove:
+ fmt = 'Removed encoding pragma from {filename}'
+ else:
+ fmt = 'Added `{pragma}` to {filename}'
+
for filename in args.filenames:
- with io.open(filename, 'r+b') as f:
- file_ret = fix_encoding_pragma(f)
+ with open(filename, 'r+b') as f:
+ file_ret = fix_encoding_pragma(f, remove=args.remove)
retv |= file_ret
if file_ret:
- print('Added `{0}` to {1}'.format(
- expected_pragma.strip(), filename,
- ))
+ # expected_pragma is bytes ending in a newline: show it as clean text
+ print(fmt.format(
+ pragma=expected_pragma.decode('UTF-8').strip(),
+ filename=filename,
+ ))
return retv