trailing-whitespace hook: switch from fileinput to a tempfile, stripping trailing whitespace in binary mode
diff --git a/pre_commit_hooks/trailing_whitespace_fixer.py b/pre_commit_hooks/trailing_whitespace_fixer.py
index 22c41ba..e39a119 100644
--- a/pre_commit_hooks/trailing_whitespace_fixer.py
+++ b/pre_commit_hooks/trailing_whitespace_fixer.py
@@ -1,24 +1,29 @@
from __future__ import print_function
import argparse
-import fileinput
import os
import sys
+import tempfile
from pre_commit_hooks.util import cmd_output
def _fix_file(filename, markdown=False):
- for line in fileinput.input([filename], inplace=True, backup='.bak'):
- # preserve trailing two-space for non-blank lines in markdown files
- if markdown and (not line.isspace()) and (line.endswith(" \n")):
- line = line.rstrip(' \n')
- # only preserve if there are no trailing tabs or unusual whitespace
- if not line[-1].isspace():
- print(line + " ")
- continue
-
- print(line.rstrip())
+ with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
+ with open(filename, 'rb') as original_file:
+ for line in original_file.readlines():
+ # preserve trailing two-space for non-blank lines in markdown files
+ if markdown and (not line.isspace()) and line.endswith(b' \n'):
+ line = line.rstrip(b' \n') # restricted stripping: e.g. \t are not stripped
+ # only preserve if there are no trailing tabs or unusual whitespace
+ if not line[-1:].isspace():
+ tmp_file.write(line + b' \n')
+ else:
+ tmp_file.write(line.rstrip() + b'\n')
+ else:
+ tmp_file.write(line.rstrip() + b'\n')
+ os.remove(filename)
+ os.rename(tmp_file.name, filename)
def fix_trailing_whitespace(argv=None):
@@ -68,15 +73,8 @@
for bad_whitespace_file in bad_whitespace_files:
print('Fixing {0}'.format(bad_whitespace_file))
_, extension = os.path.splitext(bad_whitespace_file.lower())
- try:
- _fix_file(bad_whitespace_file, all_markdown or extension in md_exts)
- return_code = 1
- # pylint: disable=broad-except
- except Exception as error: # pragma: no cover
- # e.g. error can be a UnicodeDecodeError in Python 3
- print('Ignoring {} that caused a {}'.format(bad_whitespace_file, error.__class__))
- os.remove(bad_whitespace_file)
- os.rename(bad_whitespace_file + '.bak', bad_whitespace_file)
+ _fix_file(bad_whitespace_file, all_markdown or extension in md_exts)
+ return_code = 1
return return_code
diff --git a/tests/trailing_whitespace_fixer_test.py b/tests/trailing_whitespace_fixer_test.py
index 78e6e73..3498edd 100644
--- a/tests/trailing_whitespace_fixer_test.py
+++ b/tests/trailing_whitespace_fixer_test.py
@@ -1,8 +1,6 @@
from __future__ import absolute_import
from __future__ import unicode_literals
-import sys
-
import pytest
from pre_commit_hooks.trailing_whitespace_fixer import fix_trailing_whitespace
@@ -108,8 +106,9 @@
def test_preserve_non_utf8_file(tmpdir):
+ non_utf8_bytes_content = b'<a>\xe9 \n</a>\n'
path = tmpdir.join('file.txt')
- path.write_binary(b'<a>\xe9 \n</a>')
+ path.write_binary(non_utf8_bytes_content)
ret = fix_trailing_whitespace([path.strpath])
- assert ret == (1 if sys.version_info[0] < 3 else 0) # a UnicodeDecodeError is only triggered in Python 3
- assert path.size() > 0
+ assert ret == 1
+ assert path.size() == (len(non_utf8_bytes_content) - 1)