Use the tokenizer instead of a regex to find and fix string literals
diff --git a/pre_commit_hooks/string_fixer.py b/pre_commit_hooks/string_fixer.py
index 9d2213a..9ef7a37 100644
--- a/pre_commit_hooks/string_fixer.py
+++ b/pre_commit_hooks/string_fixer.py
@@ -3,34 +3,60 @@
from __future__ import unicode_literals
import argparse
-import re
+import io
import tokenize
double_quote_starts = tuple(s for s in tokenize.single_quoted if '"' in s)
-compiled_tokenize_string = re.compile('(?<!")' + tokenize.String + '(?!")')
-def handle_match(m):
- string = m.group(0)
+def handle_match(token_text):
+ if '"""' in token_text or "'''" in token_text:
+ return token_text
for double_quote_start in double_quote_starts:
- if string.startswith(double_quote_start):
- meat = string[len(double_quote_start):-1]
+ if token_text.startswith(double_quote_start):
+ meat = token_text[len(double_quote_start):-1]
if '"' in meat or "'" in meat:
break
return double_quote_start.replace('"', "'") + meat + "'"
- return string
+ return token_text
+
+
+def get_line_offsets_by_line_no(src):
+ # Index 0 is padding so the list can be indexed by 1-based line number
+ offsets = [None, 0]
+ for line in src.splitlines():
+ offsets.append(offsets[-1] + len(line) + 1)
+ return offsets
def fix_strings(filename):
- contents = open(filename).read()
- new_contents = compiled_tokenize_string.sub(handle_match, contents)
- retval = int(new_contents != contents)
- if retval:
- with open(filename, 'w') as write_handle:
+ contents = io.open(filename).read()
+ line_offsets = get_line_offsets_by_line_no(contents)
+
+ # Basically a mutable string
+ splitcontents = list(contents)
+
+ # Iterate in reverse so the offsets are always correct
+ tokens = reversed(list(tokenize.generate_tokens(
+ io.StringIO(contents).readline,
+ )))
+ for token_type, token_text, (srow, scol), (erow, ecol), _ in tokens:
+ if token_type == tokenize.STRING:
+ new_text = handle_match(token_text)
+ splitcontents[
+ line_offsets[srow] + scol:
+ line_offsets[erow] + ecol
+ ] = new_text
+
+ new_contents = ''.join(splitcontents)
+ if contents != new_contents:
+ with io.open(filename, 'w') as write_handle:
write_handle.write(new_contents)
- return retval
+ return 1
+ else:
+ return 0
def main(argv=None):
diff --git a/tests/string_fixer_test.py b/tests/string_fixer_test.py
index 15b9f19..6305618 100644
--- a/tests/string_fixer_test.py
+++ b/tests/string_fixer_test.py
@@ -2,79 +2,49 @@
from __future__ import print_function
from __future__ import unicode_literals
+import textwrap
+
import pytest
from pre_commit_hooks.string_fixer import main
TESTS = (
# Base cases
- (
- "''",
- "''",
- 0
- ),
- (
- '""',
- "''",
- 1
- ),
- (
- r'"\'"',
- r'"\'"',
- 0
- ),
- (
- r'"\""',
- r'"\""',
- 0
- ),
- (
- r"'\"\"'",
- r"'\"\"'",
- 0
- ),
+ ("''", "''", 0),
+ ('""', "''", 1),
+ (r'"\'"', r'"\'"', 0),
+ (r'"\""', r'"\""', 0),
+ (r"'\"\"'", r"'\"\"'", 0),
# String somewhere in the line
- (
- 'x = "foo"',
- "x = 'foo'",
- 1
- ),
+ ('x = "foo"', "x = 'foo'", 1),
# Test escaped characters
- (
- r'"\'"',
- r'"\'"',
- 0
- ),
+ (r'"\'"', r'"\'"', 0),
# Docstring
+ ('""" Foo """', '""" Foo """', 0),
(
- '""" Foo """',
- '""" Foo """',
- 0
- ),
- # Fuck it, won't even try to fix
- (
- """
- x = " \\n
- foo \\n
+ textwrap.dedent("""
+ x = " \\
+ foo \\
"\n
- """,
- """
- x = " \\n
- foo \\n
- "\n
- """,
- 0
+ """),
+ textwrap.dedent("""
+ x = ' \\
+ foo \\
+ '\n
+ """),
+ 1,
),
+ ('"foo""bar"', "'foo''bar'", 1),
)
-@pytest.mark.parametrize(('input_s', 'expected_output', 'expected_retval'), TESTS)
-def test_rewrite(input_s, expected_output, expected_retval, tmpdir):
+@pytest.mark.parametrize(('input_s', 'output', 'expected_retval'), TESTS)
+def test_rewrite(input_s, output, expected_retval, tmpdir):
tmpfile = tmpdir.join('file.txt')
with open(tmpfile.strpath, 'w') as f:
f.write(input_s)
retval = main([tmpfile.strpath])
- assert tmpfile.read() == expected_output
+ assert tmpfile.read() == output
assert retval == expected_retval