Some style tweaks
diff --git a/pre_commit_hooks/requirements_txt_fixer.py b/pre_commit_hooks/requirements_txt_fixer.py
index 41e1ffc..ffabf2a 100644
--- a/pre_commit_hooks/requirements_txt_fixer.py
+++ b/pre_commit_hooks/requirements_txt_fixer.py
@@ -3,6 +3,10 @@
import argparse
+PASS = 0
+FAIL = 1
+
+
class Requirement(object):
def __init__(self):
@@ -30,14 +34,14 @@
def fix_requirements(f):
requirements = []
- before = list(f)
+ before = tuple(f)
after = []
before_string = b''.join(before)
# If the file is empty (i.e. only whitespace/newlines) exit early
if before_string.strip() == b'':
- return 0
+ return PASS
for line in before:
# If the most recent requirement object has a value, then it's
@@ -60,19 +64,18 @@
requirement.value = line
for requirement in sorted(requirements):
- for comment in requirement.comments:
- after.append(comment)
+ after.extend(requirement.comments)
after.append(requirement.value)
after_string = b''.join(after)
if before_string == after_string:
- return 0
+ return PASS
else:
f.seek(0)
f.write(after_string)
f.truncate()
- return 1
+ return FAIL
def fix_requirements_txt(argv=None):
@@ -80,7 +83,7 @@
parser.add_argument('filenames', nargs='*', help='Filenames to fix')
args = parser.parse_args(argv)
- retv = 0
+ retv = PASS
for arg in args.filenames:
with open(arg, 'rb+') as file_obj:
diff --git a/tests/requirements_txt_fixer_test.py b/tests/requirements_txt_fixer_test.py
index 33f6a47..3681cc6 100644
--- a/tests/requirements_txt_fixer_test.py
+++ b/tests/requirements_txt_fixer_test.py
@@ -1,33 +1,41 @@
import pytest
+from pre_commit_hooks.requirements_txt_fixer import FAIL
from pre_commit_hooks.requirements_txt_fixer import fix_requirements_txt
+from pre_commit_hooks.requirements_txt_fixer import PASS
from pre_commit_hooks.requirements_txt_fixer import Requirement
-# Input, expected return value, expected output
-TESTS = (
- (b'', 0, b''),
- (b'\n', 0, b'\n'),
- (b'foo\nbar\n', 1, b'bar\nfoo\n'),
- (b'bar\nfoo\n', 0, b'bar\nfoo\n'),
- (b'#comment1\nfoo\n#comment2\nbar\n', 1, b'#comment2\nbar\n#comment1\nfoo\n'),
- (b'#comment1\nbar\n#comment2\nfoo\n', 0, b'#comment1\nbar\n#comment2\nfoo\n'),
- (b'#comment\n\nfoo\nbar\n', 1, b'#comment\n\nbar\nfoo\n'),
- (b'#comment\n\nbar\nfoo\n', 0, b'#comment\n\nbar\nfoo\n'),
- (b'\nfoo\nbar\n', 1, b'bar\n\nfoo\n'),
- (b'\nbar\nfoo\n', 0, b'\nbar\nfoo\n'),
- (b'pyramid==1\npyramid-foo==2\n', 0, b'pyramid==1\npyramid-foo==2\n'),
- (b'ocflib\nDjango\nPyMySQL\n', 1, b'Django\nocflib\nPyMySQL\n'),
- (b'-e git+ssh://git_url@tag#egg=ocflib\nDjango\nPyMySQL\n', 1, b'Django\n-e git+ssh://git_url@tag#egg=ocflib\nPyMySQL\n'),
+
+@pytest.mark.parametrize(
+ ('input_s', 'expected_retval', 'output'),
+ (
+ (b'', PASS, b''),
+ (b'\n', PASS, b'\n'),
+ (b'foo\nbar\n', FAIL, b'bar\nfoo\n'),
+ (b'bar\nfoo\n', PASS, b'bar\nfoo\n'),
+ (b'#comment1\nfoo\n#comment2\nbar\n', FAIL, b'#comment2\nbar\n#comment1\nfoo\n'),
+ (b'#comment1\nbar\n#comment2\nfoo\n', PASS, b'#comment1\nbar\n#comment2\nfoo\n'),
+ (b'#comment\n\nfoo\nbar\n', FAIL, b'#comment\n\nbar\nfoo\n'),
+ (b'#comment\n\nbar\nfoo\n', PASS, b'#comment\n\nbar\nfoo\n'),
+ (b'\nfoo\nbar\n', FAIL, b'bar\n\nfoo\n'),
+ (b'\nbar\nfoo\n', PASS, b'\nbar\nfoo\n'),
+ (b'pyramid==1\npyramid-foo==2\n', PASS, b'pyramid==1\npyramid-foo==2\n'),
+ (b'ocflib\nDjango\nPyMySQL\n', FAIL, b'Django\nocflib\nPyMySQL\n'),
+ (
+ b'-e git+ssh://git_url@tag#egg=ocflib\nDjango\nPyMySQL\n',
+ FAIL,
+ b'Django\n-e git+ssh://git_url@tag#egg=ocflib\nPyMySQL\n'
+ ),
+ )
)
-
-
-@pytest.mark.parametrize(('input_s', 'expected_retval', 'output'), TESTS)
def test_integration(input_s, expected_retval, output, tmpdir):
path = tmpdir.join('file.txt')
path.write_binary(input_s)
- assert fix_requirements_txt([path.strpath]) == expected_retval
+ output_retval = fix_requirements_txt([path.strpath])
+
assert path.read_binary() == output
+ assert output_retval == expected_retval
def test_requirement_object():