Apply typing to all of pre-commit-hooks
diff --git a/.gitignore b/.gitignore index c00e966..32c2fec 100644 --- a/.gitignore +++ b/.gitignore
@@ -1,16 +1,11 @@ *.egg-info -*.iml *.py[co] .*.sw[a-z] -.pytest_cache .coverage -.idea -.project -.pydevproject .tox .venv.touch +/.mypy_cache +/.pytest_cache /venv* coverage-html dist -# SublimeText project/workspace files -*.sublime-*
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8bd0fdc..4990537 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml
@@ -27,7 +27,7 @@ rev: v1.3.5 hooks: - id: reorder-python-imports - language_version: python2.7 + language_version: python3 - repo: https://github.com/asottile/pyupgrade rev: v1.11.1 hooks: @@ -36,3 +36,8 @@ rev: v0.7.1 hooks: - id: add-trailing-comma +- repo: https://github.com/pre-commit/mirrors-mypy + rev: v0.660 + hooks: + - id: mypy + language_version: python3
diff --git a/.travis.yml b/.travis.yml index 477b5c4..fa16cce 100644 --- a/.travis.yml +++ b/.travis.yml
@@ -1,3 +1,4 @@ +dist: xenial language: python matrix: include: # These should match the tox env list @@ -6,9 +7,8 @@ python: 3.6 - env: TOXENV=py37 python: 3.7 - dist: xenial - env: TOXENV=pypy - python: pypy-5.7.1 + python: pypy2.7-5.10.0 install: pip install coveralls tox script: tox before_install:
diff --git a/get-git-lfs.py b/get-git-lfs.py index 48dd31e..4b09cac 100755 --- a/get-git-lfs.py +++ b/get-git-lfs.py
@@ -4,7 +4,9 @@ import os.path import shutil import tarfile -from urllib.request import urlopen +import urllib.request +from typing import cast +from typing import IO DOWNLOAD_PATH = ( 'https://github.com/github/git-lfs/releases/download/' @@ -15,7 +17,7 @@ DEST_DIR = os.path.dirname(DEST_PATH) -def main(): +def main(): # type: () -> int if ( os.path.exists(DEST_PATH) and os.path.isfile(DEST_PATH) and @@ -27,12 +29,13 @@ shutil.rmtree(DEST_DIR, ignore_errors=True) os.makedirs(DEST_DIR, exist_ok=True) - contents = io.BytesIO(urlopen(DOWNLOAD_PATH).read()) + contents = io.BytesIO(urllib.request.urlopen(DOWNLOAD_PATH).read()) with tarfile.open(fileobj=contents) as tar: - with tar.extractfile(PATH_IN_TAR) as src_file: + with cast(IO[bytes], tar.extractfile(PATH_IN_TAR)) as src_file: with open(DEST_PATH, 'wb') as dest_file: shutil.copyfileobj(src_file, dest_file) os.chmod(DEST_PATH, 0o755) + return 0 if __name__ == '__main__':
diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..ee62c89 --- /dev/null +++ b/mypy.ini
@@ -0,0 +1,12 @@ +[mypy] +check_untyped_defs = true +disallow_any_generics = true +disallow_incomplete_defs = true +disallow_untyped_defs = true +no_implicit_optional = true + +[mypy-testing.*] +disallow_untyped_defs = false + +[mypy-tests.*] +disallow_untyped_defs = false
diff --git a/pre_commit_hooks/autopep8_wrapper.py b/pre_commit_hooks/autopep8_wrapper.py index 9951924..8b69a04 100644 --- a/pre_commit_hooks/autopep8_wrapper.py +++ b/pre_commit_hooks/autopep8_wrapper.py
@@ -3,7 +3,7 @@ from __future__ import unicode_literals -def main(argv=None): +def main(): # type: () -> int raise SystemExit( 'autopep8-wrapper is deprecated. Instead use autopep8 directly via ' 'https://github.com/pre-commit/mirrors-autopep8',
diff --git a/pre_commit_hooks/check_added_large_files.py b/pre_commit_hooks/check_added_large_files.py index 2d06706..be39498 100644 --- a/pre_commit_hooks/check_added_large_files.py +++ b/pre_commit_hooks/check_added_large_files.py
@@ -7,13 +7,17 @@ import json import math import os +from typing import Iterable +from typing import Optional +from typing import Sequence +from typing import Set from pre_commit_hooks.util import added_files from pre_commit_hooks.util import CalledProcessError from pre_commit_hooks.util import cmd_output -def lfs_files(): +def lfs_files(): # type: () -> Set[str] try: # Introduced in git-lfs 2.2.0, first working in 2.2.1 lfs_ret = cmd_output('git', 'lfs', 'status', '--json') @@ -24,6 +28,7 @@ def find_large_added_files(filenames, maxkb): + # type: (Iterable[str], int) -> int # Find all added files that are also in the list of files pre-commit tells # us about filenames = (added_files() & set(filenames)) - lfs_files() @@ -38,7 +43,7 @@ return retv -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument( 'filenames', nargs='*',
diff --git a/pre_commit_hooks/check_ast.py b/pre_commit_hooks/check_ast.py index ded65e4..0df3540 100644 --- a/pre_commit_hooks/check_ast.py +++ b/pre_commit_hooks/check_ast.py
@@ -7,9 +7,11 @@ import platform import sys import traceback +from typing import Optional +from typing import Sequence -def check_ast(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*') args = parser.parse_args(argv) @@ -34,4 +36,4 @@ if __name__ == '__main__': - exit(check_ast()) + exit(main())
diff --git a/pre_commit_hooks/check_builtin_literals.py b/pre_commit_hooks/check_builtin_literals.py index 4a4b9ce..874c68c 100644 --- a/pre_commit_hooks/check_builtin_literals.py +++ b/pre_commit_hooks/check_builtin_literals.py
@@ -4,6 +4,10 @@ import ast import collections import sys +from typing import List +from typing import Optional +from typing import Sequence +from typing import Set BUILTIN_TYPES = { @@ -22,14 +26,17 @@ class BuiltinTypeVisitor(ast.NodeVisitor): def __init__(self, ignore=None, allow_dict_kwargs=True): - self.builtin_type_calls = [] + # type: (Optional[Sequence[str]], bool) -> None + self.builtin_type_calls = [] # type: List[BuiltinTypeCall] self.ignore = set(ignore) if ignore else set() self.allow_dict_kwargs = allow_dict_kwargs - def _check_dict_call(self, node): + def _check_dict_call(self, node): # type: (ast.Call) -> bool + return self.allow_dict_kwargs and (getattr(node, 'kwargs', None) or getattr(node, 'keywords', None)) - def visit_Call(self, node): + def visit_Call(self, node): # type: (ast.Call) -> None + if not isinstance(node.func, ast.Name): # Ignore functions that are object attributes (`foo.bar()`). # Assume that if the user calls `builtins.list()`, they know what @@ -47,6 +54,7 @@ def check_file_for_builtin_type_constructors(filename, ignore=None, allow_dict_kwargs=True): + # type: (str, Optional[Sequence[str]], bool) -> List[BuiltinTypeCall] with open(filename, 'rb') as f: tree = ast.parse(f.read(), filename=filename) visitor = BuiltinTypeVisitor(ignore=ignore, allow_dict_kwargs=allow_dict_kwargs) @@ -54,24 +62,22 @@ return visitor.builtin_type_calls -def parse_args(argv): - def parse_ignore(value): - return set(value.split(',')) +def parse_ignore(value): # type: (str) -> Set[str] + return set(value.split(',')) + +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*') parser.add_argument('--ignore', type=parse_ignore, default=set()) - allow_dict_kwargs = parser.add_mutually_exclusive_group(required=False) - allow_dict_kwargs.add_argument('--allow-dict-kwargs', action='store_true') - allow_dict_kwargs.add_argument('--no-allow-dict-kwargs', dest='allow_dict_kwargs', action='store_false') - allow_dict_kwargs.set_defaults(allow_dict_kwargs=True) + mutex = parser.add_mutually_exclusive_group(required=False) + mutex.add_argument('--allow-dict-kwargs', action='store_true') + mutex.add_argument('--no-allow-dict-kwargs', dest='allow_dict_kwargs', action='store_false') + mutex.set_defaults(allow_dict_kwargs=True) - return parser.parse_args(argv) + args = parser.parse_args(argv) - -def main(argv=None): - args = parse_args(argv) rc = 0 for filename in args.filenames: calls = check_file_for_builtin_type_constructors(
diff --git a/pre_commit_hooks/check_byte_order_marker.py b/pre_commit_hooks/check_byte_order_marker.py index 1541b30..10667c3 100644 --- a/pre_commit_hooks/check_byte_order_marker.py +++ b/pre_commit_hooks/check_byte_order_marker.py
@@ -3,9 +3,11 @@ from __future__ import unicode_literals import argparse +from typing import Optional +from typing import Sequence -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*', help='Filenames to check') args = parser.parse_args(argv)
diff --git a/pre_commit_hooks/check_case_conflict.py b/pre_commit_hooks/check_case_conflict.py index 0f78296..e343d61 100644 --- a/pre_commit_hooks/check_case_conflict.py +++ b/pre_commit_hooks/check_case_conflict.py
@@ -3,16 +3,20 @@ from __future__ import unicode_literals import argparse +from typing import Iterable +from typing import Optional +from typing import Sequence +from typing import Set from pre_commit_hooks.util import added_files from pre_commit_hooks.util import cmd_output -def lower_set(iterable): +def lower_set(iterable): # type: (Iterable[str]) -> Set[str] return {x.lower() for x in iterable} -def find_conflicting_filenames(filenames): +def find_conflicting_filenames(filenames): # type: (Sequence[str]) -> int repo_files = set(cmd_output('git', 'ls-files').splitlines()) relevant_files = set(filenames) | added_files() repo_files -= relevant_files @@ -41,7 +45,7 @@ return retv -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument( 'filenames', nargs='*',
diff --git a/pre_commit_hooks/check_docstring_first.py b/pre_commit_hooks/check_docstring_first.py index 9988378..f4639f1 100644 --- a/pre_commit_hooks/check_docstring_first.py +++ b/pre_commit_hooks/check_docstring_first.py
@@ -5,6 +5,8 @@ import argparse import io import tokenize +from typing import Optional +from typing import Sequence NON_CODE_TOKENS = frozenset(( @@ -13,6 +15,7 @@ def check_docstring_first(src, filename='<unknown>'): + # type: (str, str) -> int """Returns nonzero if the source has what looks like a docstring that is not at the beginning of the source. @@ -50,7 +53,7 @@ return 0 -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*') args = parser.parse_args(argv)
diff --git a/pre_commit_hooks/check_executables_have_shebangs.py b/pre_commit_hooks/check_executables_have_shebangs.py index 89ac6e5..c936a5d 100644 --- a/pre_commit_hooks/check_executables_have_shebangs.py +++ b/pre_commit_hooks/check_executables_have_shebangs.py
@@ -6,9 +6,11 @@ import argparse import pipes import sys +from typing import Optional +from typing import Sequence -def check_has_shebang(path): +def check_has_shebang(path): # type: (str) -> int with open(path, 'rb') as f: first_bytes = f.read(2) @@ -27,7 +29,7 @@ return 0 -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser(description=__doc__) parser.add_argument('filenames', nargs='*') args = parser.parse_args(argv) @@ -38,3 +40,7 @@ retv |= check_has_shebang(filename) return retv + + +if __name__ == '__main__': + exit(main())
diff --git a/pre_commit_hooks/check_json.py b/pre_commit_hooks/check_json.py index b403f4b..b939350 100644 --- a/pre_commit_hooks/check_json.py +++ b/pre_commit_hooks/check_json.py
@@ -4,9 +4,11 @@ import io import json import sys +from typing import Optional +from typing import Sequence -def check_json(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*', help='JSON filenames to check.') args = parser.parse_args(argv) @@ -22,4 +24,4 @@ if __name__ == '__main__': - sys.exit(check_json()) + sys.exit(main())
diff --git a/pre_commit_hooks/check_merge_conflict.py b/pre_commit_hooks/check_merge_conflict.py index 6db5efe..74e4ae1 100644 --- a/pre_commit_hooks/check_merge_conflict.py +++ b/pre_commit_hooks/check_merge_conflict.py
@@ -2,6 +2,9 @@ import argparse import os.path +from typing import Optional +from typing import Sequence + CONFLICT_PATTERNS = [ b'<<<<<<< ', @@ -12,7 +15,7 @@ WARNING_MSG = 'Merge conflict string "{0}" found in {1}:{2}' -def is_in_merge(): +def is_in_merge(): # type: () -> int return ( os.path.exists(os.path.join('.git', 'MERGE_MSG')) and ( @@ -23,7 +26,7 @@ ) -def detect_merge_conflict(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*') parser.add_argument('--assume-in-merge', action='store_true') @@ -47,4 +50,4 @@ if __name__ == '__main__': - exit(detect_merge_conflict()) + exit(main())
diff --git a/pre_commit_hooks/check_symlinks.py b/pre_commit_hooks/check_symlinks.py index 010c871..736bf99 100644 --- a/pre_commit_hooks/check_symlinks.py +++ b/pre_commit_hooks/check_symlinks.py
@@ -4,9 +4,11 @@ import argparse import os.path +from typing import Optional +from typing import Sequence -def check_symlinks(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser(description='Checks for broken symlinks.') parser.add_argument('filenames', nargs='*', help='Filenames to check') args = parser.parse_args(argv) @@ -25,4 +27,4 @@ if __name__ == '__main__': - exit(check_symlinks()) + exit(main())
diff --git a/pre_commit_hooks/check_vcs_permalinks.py b/pre_commit_hooks/check_vcs_permalinks.py index f0dcf5b..f6e2a7d 100644 --- a/pre_commit_hooks/check_vcs_permalinks.py +++ b/pre_commit_hooks/check_vcs_permalinks.py
@@ -5,6 +5,8 @@ import argparse import re import sys +from typing import Optional +from typing import Sequence GITHUB_NON_PERMALINK = re.compile( @@ -12,7 +14,7 @@ ) -def _check_filename(filename): +def _check_filename(filename): # type: (str) -> int retv = 0 with open(filename, 'rb') as f: for i, line in enumerate(f, 1): @@ -24,7 +26,7 @@ return retv -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*') args = parser.parse_args(argv)
diff --git a/pre_commit_hooks/check_xml.py b/pre_commit_hooks/check_xml.py index a4c11a5..66e10ba 100644 --- a/pre_commit_hooks/check_xml.py +++ b/pre_commit_hooks/check_xml.py
@@ -5,10 +5,12 @@ import argparse import io import sys -import xml.sax +import xml.sax.handler +from typing import Optional +from typing import Sequence -def check_xml(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*', help='XML filenames to check.') args = parser.parse_args(argv) @@ -17,7 +19,7 @@ for filename in args.filenames: try: with io.open(filename, 'rb') as xml_file: - xml.sax.parse(xml_file, xml.sax.ContentHandler()) + xml.sax.parse(xml_file, xml.sax.handler.ContentHandler()) except xml.sax.SAXException as exc: print('{}: Failed to xml parse ({})'.format(filename, exc)) retval = 1 @@ -25,4 +27,4 @@ if __name__ == '__main__': - sys.exit(check_xml()) + sys.exit(main())
diff --git a/pre_commit_hooks/check_yaml.py b/pre_commit_hooks/check_yaml.py index 208737f..b638684 100644 --- a/pre_commit_hooks/check_yaml.py +++ b/pre_commit_hooks/check_yaml.py
@@ -3,22 +3,26 @@ import argparse import collections import sys +from typing import Any +from typing import Generator +from typing import Optional +from typing import Sequence import ruamel.yaml yaml = ruamel.yaml.YAML(typ='safe') -def _exhaust(gen): +def _exhaust(gen): # type: (Generator[str, None, None]) -> None for _ in gen: pass -def _parse_unsafe(*args, **kwargs): +def _parse_unsafe(*args, **kwargs): # type: (*Any, **Any) -> None _exhaust(yaml.parse(*args, **kwargs)) -def _load_all(*args, **kwargs): +def _load_all(*args, **kwargs): # type: (*Any, **Any) -> None _exhaust(yaml.load_all(*args, **kwargs)) @@ -31,7 +35,7 @@ } -def check_yaml(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument( '-m', '--multi', '--allow-multiple-documents', action='store_true', @@ -63,4 +67,4 @@ if __name__ == '__main__': - sys.exit(check_yaml()) + sys.exit(main())
diff --git a/pre_commit_hooks/debug_statement_hook.py b/pre_commit_hooks/debug_statement_hook.py index 5d32277..02dd3b2 100644 --- a/pre_commit_hooks/debug_statement_hook.py +++ b/pre_commit_hooks/debug_statement_hook.py
@@ -5,6 +5,9 @@ import ast import collections import traceback +from typing import List +from typing import Optional +from typing import Sequence DEBUG_STATEMENTS = {'pdb', 'ipdb', 'pudb', 'q', 'rdb'} @@ -12,21 +15,21 @@ class DebugStatementParser(ast.NodeVisitor): - def __init__(self): - self.breakpoints = [] + def __init__(self): # type: () -> None + self.breakpoints = [] # type: List[Debug] - def visit_Import(self, node): + def visit_Import(self, node): # type: (ast.Import) -> None for name in node.names: if name.name in DEBUG_STATEMENTS: st = Debug(node.lineno, node.col_offset, name.name, 'imported') self.breakpoints.append(st) - def visit_ImportFrom(self, node): + def visit_ImportFrom(self, node): # type: (ast.ImportFrom) -> None if node.module in DEBUG_STATEMENTS: st = Debug(node.lineno, node.col_offset, node.module, 'imported') self.breakpoints.append(st) - def visit_Call(self, node): + def visit_Call(self, node): # type: (ast.Call) -> None """python3.7+ breakpoint()""" if isinstance(node.func, ast.Name) and node.func.id == 'breakpoint': st = Debug(node.lineno, node.col_offset, node.func.id, 'called') @@ -34,7 +37,7 @@ self.generic_visit(node) -def check_file(filename): +def check_file(filename): # type: (str) -> int try: with open(filename, 'rb') as f: ast_obj = ast.parse(f.read(), filename=filename) @@ -58,7 +61,7 @@ return int(bool(visitor.breakpoints)) -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*', help='Filenames to run') args = parser.parse_args(argv)
diff --git a/pre_commit_hooks/detect_aws_credentials.py b/pre_commit_hooks/detect_aws_credentials.py index ecd9d40..3c87d11 100644 --- a/pre_commit_hooks/detect_aws_credentials.py +++ b/pre_commit_hooks/detect_aws_credentials.py
@@ -3,11 +3,16 @@ import argparse import os +from typing import Dict +from typing import List +from typing import Optional +from typing import Sequence +from typing import Set from six.moves import configparser -def get_aws_credential_files_from_env(): +def get_aws_credential_files_from_env(): # type: () -> Set[str] """Extract credential file paths from environment variables.""" files = set() for env_var in ( @@ -19,7 +24,7 @@ return files -def get_aws_secrets_from_env(): +def get_aws_secrets_from_env(): # type: () -> Set[str] """Extract AWS secrets from environment variables.""" keys = set() for env_var in ( @@ -30,7 +35,7 @@ return keys -def get_aws_secrets_from_file(credentials_file): +def get_aws_secrets_from_file(credentials_file): # type: (str) -> Set[str] """Extract AWS secrets from configuration files. Read an ini-style configuration file and return a set with all found AWS @@ -62,6 +67,7 @@ def check_file_for_aws_keys(filenames, keys): + # type: (Sequence[str], Set[str]) -> List[Dict[str, str]] """Check if files contain AWS secrets. Return a list of all files containing AWS secrets and keys found, with all @@ -82,7 +88,7 @@ return bad_files -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='+', help='Filenames to run') parser.add_argument( @@ -111,7 +117,7 @@ # of files to to gather AWS secrets from. credential_files |= get_aws_credential_files_from_env() - keys = set() + keys = set() # type: Set[str] for credential_file in credential_files: keys |= get_aws_secrets_from_file(credential_file)
diff --git a/pre_commit_hooks/detect_private_key.py b/pre_commit_hooks/detect_private_key.py index c8ee961..d31957d 100644 --- a/pre_commit_hooks/detect_private_key.py +++ b/pre_commit_hooks/detect_private_key.py
@@ -2,6 +2,8 @@ import argparse import sys +from typing import Optional +from typing import Sequence BLACKLIST = [ b'BEGIN RSA PRIVATE KEY', @@ -15,7 +17,7 @@ ] -def detect_private_key(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*', help='Filenames to check') args = parser.parse_args(argv) @@ -37,4 +39,4 @@ if __name__ == '__main__': - sys.exit(detect_private_key()) + sys.exit(main())
diff --git a/pre_commit_hooks/end_of_file_fixer.py b/pre_commit_hooks/end_of_file_fixer.py index 5ab1b7b..4e77c94 100644 --- a/pre_commit_hooks/end_of_file_fixer.py +++ b/pre_commit_hooks/end_of_file_fixer.py
@@ -4,9 +4,12 @@ import argparse import os import sys +from typing import IO +from typing import Optional +from typing import Sequence -def fix_file(file_obj): +def fix_file(file_obj): # type: (IO[bytes]) -> int # Test for newline at end of file # Empty files will throw IOError here try: @@ -49,7 +52,7 @@ return 0 -def end_of_file_fixer(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*', help='Filenames to fix') args = parser.parse_args(argv) @@ -68,4 +71,4 @@ if __name__ == '__main__': - sys.exit(end_of_file_fixer()) + sys.exit(main())
diff --git a/pre_commit_hooks/file_contents_sorter.py b/pre_commit_hooks/file_contents_sorter.py index fe7f7ee..6f13c98 100644 --- a/pre_commit_hooks/file_contents_sorter.py +++ b/pre_commit_hooks/file_contents_sorter.py
@@ -12,12 +12,15 @@ from __future__ import print_function import argparse +from typing import IO +from typing import Optional +from typing import Sequence PASS = 0 FAIL = 1 -def sort_file_contents(f): +def sort_file_contents(f): # type: (IO[bytes]) -> int before = list(f) after = sorted([line.strip(b'\n\r') for line in before if line.strip()]) @@ -33,7 +36,7 @@ return FAIL -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='+', help='Files to sort') args = parser.parse_args(argv)
diff --git a/pre_commit_hooks/fix_encoding_pragma.py b/pre_commit_hooks/fix_encoding_pragma.py index 3bf234e..b0b5c8e 100644 --- a/pre_commit_hooks/fix_encoding_pragma.py +++ b/pre_commit_hooks/fix_encoding_pragma.py
@@ -4,11 +4,15 @@ import argparse import collections +from typing import IO +from typing import Optional +from typing import Sequence +from typing import Union DEFAULT_PRAGMA = b'# -*- coding: utf-8 -*-\n' -def has_coding(line): +def has_coding(line): # type: (bytes) -> bool if not line.strip(): return False return ( @@ -33,15 +37,16 @@ __slots__ = () @property - def has_any_pragma(self): + def has_any_pragma(self): # type: () -> bool return self.pragma_status is not False - def is_expected_pragma(self, remove): + def is_expected_pragma(self, remove): # type: (bool) -> bool expected_pragma_status = not remove return self.pragma_status is expected_pragma_status def _get_expected_contents(first_line, second_line, rest, expected_pragma): + # type: (bytes, bytes, bytes, bytes) -> ExpectedContents if first_line.startswith(b'#!'): shebang = first_line potential_coding = second_line @@ -51,7 +56,7 @@ rest = second_line + rest if potential_coding == expected_pragma: - pragma_status = True + pragma_status = True # type: Optional[bool] elif has_coding(potential_coding): pragma_status = None else: @@ -64,6 +69,7 @@ def fix_encoding_pragma(f, remove=False, expected_pragma=DEFAULT_PRAGMA): + # type: (IO[bytes], bool, bytes) -> int expected = _get_expected_contents( f.readline(), f.readline(), f.read(), expected_pragma, ) @@ -93,17 +99,17 @@ return 1 -def _normalize_pragma(pragma): +def _normalize_pragma(pragma): # type: (Union[bytes, str]) -> bytes if not isinstance(pragma, bytes): pragma = pragma.encode('UTF-8') return pragma.rstrip() + b'\n' -def _to_disp(pragma): +def _to_disp(pragma): # type: (bytes) -> str return pragma.decode().rstrip() -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser('Fixes the encoding pragma of python files') parser.add_argument('filenames', nargs='*', help='Filenames to fix') parser.add_argument(
diff --git a/pre_commit_hooks/forbid_new_submodules.py b/pre_commit_hooks/forbid_new_submodules.py index c9464cf..bdbd6f7 100644 --- a/pre_commit_hooks/forbid_new_submodules.py +++ b/pre_commit_hooks/forbid_new_submodules.py
@@ -2,10 +2,13 @@ from __future__ import print_function from __future__ import unicode_literals +from typing import Optional +from typing import Sequence + from pre_commit_hooks.util import cmd_output -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int # `argv` is ignored, pre-commit will send us a list of files that we # don't care about added_diff = cmd_output(
diff --git a/pre_commit_hooks/mixed_line_ending.py b/pre_commit_hooks/mixed_line_ending.py index e35a65c..90aef03 100644 --- a/pre_commit_hooks/mixed_line_ending.py +++ b/pre_commit_hooks/mixed_line_ending.py
@@ -4,6 +4,9 @@ import argparse import collections +from typing import Dict +from typing import Optional +from typing import Sequence CRLF = b'\r\n' @@ -14,7 +17,7 @@ FIX_TO_LINE_ENDING = {'cr': CR, 'crlf': CRLF, 'lf': LF} -def _fix(filename, contents, ending): +def _fix(filename, contents, ending): # type: (str, bytes, bytes) -> None new_contents = b''.join( line.rstrip(b'\r\n') + ending for line in contents.splitlines(True) ) @@ -22,11 +25,11 @@ f.write(new_contents) -def fix_filename(filename, fix): +def fix_filename(filename, fix): # type: (str, str) -> int with open(filename, 'rb') as f: contents = f.read() - counts = collections.defaultdict(int) + counts = collections.defaultdict(int) # type: Dict[bytes, int] for line in contents.splitlines(True): for ending in ALL_ENDINGS: @@ -63,7 +66,7 @@ return other_endings -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument( '-f', '--fix',
diff --git a/pre_commit_hooks/no_commit_to_branch.py b/pre_commit_hooks/no_commit_to_branch.py index fdd146b..6b68c91 100644 --- a/pre_commit_hooks/no_commit_to_branch.py +++ b/pre_commit_hooks/no_commit_to_branch.py
@@ -1,12 +1,15 @@ from __future__ import print_function import argparse +from typing import Optional +from typing import Sequence +from typing import Set from pre_commit_hooks.util import CalledProcessError from pre_commit_hooks.util import cmd_output -def is_on_branch(protected): +def is_on_branch(protected): # type: (Set[str]) -> bool try: branch = cmd_output('git', 'symbolic-ref', 'HEAD') except CalledProcessError: @@ -15,7 +18,7 @@ return '/'.join(chunks[2:]) in protected -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument( '-b', '--branch', action='append',
diff --git a/pre_commit_hooks/pretty_format_json.py b/pre_commit_hooks/pretty_format_json.py index 363037e..de7f8d7 100644 --- a/pre_commit_hooks/pretty_format_json.py +++ b/pre_commit_hooks/pretty_format_json.py
@@ -5,12 +5,20 @@ import json import sys from collections import OrderedDict +from typing import List +from typing import Mapping +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Union from six import text_type -def _get_pretty_format(contents, indent, ensure_ascii=True, sort_keys=True, top_keys=[]): +def _get_pretty_format(contents, indent, ensure_ascii=True, sort_keys=True, top_keys=()): + # type: (str, str, bool, bool, Sequence[str]) -> str def pairs_first(pairs): + # type: (Sequence[Tuple[str, str]]) -> Mapping[str, str] before = [pair for pair in pairs if pair[0] in top_keys] before = sorted(before, key=lambda x: top_keys.index(x[0])) after = [pair for pair in pairs if pair[0] not in top_keys] @@ -27,13 +35,13 @@ return text_type(json_pretty) + '\n' -def _autofix(filename, new_contents): +def _autofix(filename, new_contents): # type: (str, str) -> None print('Fixing file {}'.format(filename)) with io.open(filename, 'w', encoding='UTF-8') as f: f.write(new_contents) -def parse_num_to_int(s): +def parse_num_to_int(s): # type: (str) -> Union[int, str] """Convert string numbers to int, leaving strings as is.""" try: return int(s) @@ -41,11 +49,11 @@ return s -def parse_topkeys(s): +def parse_topkeys(s): # type: (str) -> List[str] return s.split(',') -def pretty_format_json(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument( '--autofix', @@ -117,4 +125,4 @@ if __name__ == '__main__': - sys.exit(pretty_format_json()) + sys.exit(main())
diff --git a/pre_commit_hooks/requirements_txt_fixer.py b/pre_commit_hooks/requirements_txt_fixer.py index 6dcf8d0..3f85a17 100644 --- a/pre_commit_hooks/requirements_txt_fixer.py +++ b/pre_commit_hooks/requirements_txt_fixer.py
@@ -1,6 +1,10 @@ from __future__ import print_function import argparse +from typing import IO +from typing import List +from typing import Optional +from typing import Sequence PASS = 0 @@ -9,21 +13,23 @@ class Requirement(object): - def __init__(self): + def __init__(self): # type: () -> None super(Requirement, self).__init__() - self.value = None - self.comments = [] + self.value = None # type: Optional[bytes] + self.comments = [] # type: List[bytes] @property - def name(self): + def name(self): # type: () -> bytes + assert self.value is not None, self.value if self.value.startswith(b'-e '): return self.value.lower().partition(b'=')[-1] return self.value.lower().partition(b'==')[0] - def __lt__(self, requirement): + def __lt__(self, requirement): # type: (Requirement) -> int # \n means top of file comment, so always return True, # otherwise just do a string comparison with value. + assert self.value is not None, self.value if self.value == b'\n': return True elif requirement.value == b'\n': @@ -32,10 +38,10 @@ return self.name < requirement.name -def fix_requirements(f): - requirements = [] +def fix_requirements(f): # type: (IO[bytes]) -> int + requirements = [] # type: List[Requirement] before = tuple(f) - after = [] + after = [] # type: List[bytes] before_string = b''.join(before) @@ -46,6 +52,7 @@ for line in before: # If the most recent requirement object has a value, then it's # time to start building the next requirement object. + if not len(requirements) or requirements[-1].value is not None: requirements.append(Requirement()) @@ -78,6 +85,7 @@ for requirement in sorted(requirements): after.extend(requirement.comments) + assert requirement.value, requirement.value after.append(requirement.value) after.extend(rest) @@ -92,7 +100,7 @@ return FAIL -def fix_requirements_txt(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*', help='Filenames to fix') args = parser.parse_args(argv) @@ -109,3 +117,7 @@ retv |= ret_for_file return retv + + +if __name__ == '__main__': + exit(main())
diff --git a/pre_commit_hooks/sort_simple_yaml.py b/pre_commit_hooks/sort_simple_yaml.py index 7afae91..3c8ef16 100755 --- a/pre_commit_hooks/sort_simple_yaml.py +++ b/pre_commit_hooks/sort_simple_yaml.py
@@ -21,12 +21,15 @@ from __future__ import print_function import argparse +from typing import List +from typing import Optional +from typing import Sequence QUOTES = ["'", '"'] -def sort(lines): +def sort(lines): # type: (List[str]) -> List[str] """Sort a YAML file in alphabetical order, keeping blocks together. :param lines: array of strings (without newlines) @@ -44,7 +47,7 @@ return new_lines -def parse_block(lines, header=False): +def parse_block(lines, header=False): # type: (List[str], bool) -> List[str] """Parse and return a single block, popping off the start of `lines`. If parsing a header block, we stop after we reach a line that is not a @@ -60,7 +63,7 @@ return block_lines -def parse_blocks(lines): +def parse_blocks(lines): # type: (List[str]) -> List[List[str]] """Parse and return all possible blocks, popping off the start of `lines`. :param lines: list of lines @@ -77,7 +80,7 @@ return blocks -def first_key(lines): +def first_key(lines): # type: (List[str]) -> str """Returns a string representing the sort key of a block. The sort key is the first YAML key we encounter, ignoring comments, and @@ -95,9 +98,11 @@ if any(line.startswith(quote) for quote in QUOTES): return line[1:] return line + else: + return '' # not actually reached in reality -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*', help='Filenames to fix') args = parser.parse_args(argv)
diff --git a/pre_commit_hooks/string_fixer.py b/pre_commit_hooks/string_fixer.py index c432682..a5ea1ea 100644 --- a/pre_commit_hooks/string_fixer.py +++ b/pre_commit_hooks/string_fixer.py
@@ -4,34 +4,39 @@ import argparse import io +import re import tokenize +from typing import List +from typing import Optional +from typing import Sequence + +START_QUOTE_RE = re.compile('^[a-zA-Z]*"') -double_quote_starts = tuple(s for s in tokenize.single_quoted if '"' in s) - - -def handle_match(token_text): +def handle_match(token_text): # type: (str) -> str if '"""' in token_text or "'''" in token_text: return token_text - for double_quote_start in double_quote_starts: - if token_text.startswith(double_quote_start): - meat = token_text[len(double_quote_start):-1] - if '"' in meat or "'" in meat: - break - return double_quote_start.replace('"', "'") + meat + "'" - return token_text + match = START_QUOTE_RE.match(token_text) + if match is not None: + meat = token_text[match.end():-1] + if '"' in meat or "'" in meat: + return token_text + else: + return match.group().replace('"', "'") + meat + "'" + else: + return token_text -def get_line_offsets_by_line_no(src): +def get_line_offsets_by_line_no(src): # type: (str) -> List[int] # Padded so we can index with line number - offsets = [None, 0] + offsets = [-1, 0] for line in src.splitlines(): offsets.append(offsets[-1] + len(line) + 1) return offsets -def fix_strings(filename): +def fix_strings(filename): # type: (str) -> int with io.open(filename, encoding='UTF-8') as f: contents = f.read() line_offsets = get_line_offsets_by_line_no(contents) @@ -60,7 +65,7 @@ return 0 -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*', help='Filenames to fix') args = parser.parse_args(argv) @@ -74,3 +79,7 @@ retv |= return_value return retv + + +if __name__ == '__main__': + exit(main())
diff --git a/pre_commit_hooks/tests_should_end_in_test.py b/pre_commit_hooks/tests_should_end_in_test.py index 9bea20d..7a1e7c0 100644 --- a/pre_commit_hooks/tests_should_end_in_test.py +++ b/pre_commit_hooks/tests_should_end_in_test.py
@@ -1,12 +1,14 @@ from __future__ import print_function import argparse +import os.path import re import sys -from os.path import basename +from typing import Optional +from typing import Sequence -def validate_files(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*') parser.add_argument( @@ -18,7 +20,7 @@ retcode = 0 test_name_pattern = 'test.*.py' if args.django else '.*_test.py' for filename in args.filenames: - base = basename(filename) + base = os.path.basename(filename) if ( not re.match(test_name_pattern, base) and not base == '__init__.py' and @@ -35,4 +37,4 @@ if __name__ == '__main__': - sys.exit(validate_files()) + sys.exit(main())
diff --git a/pre_commit_hooks/trailing_whitespace_fixer.py b/pre_commit_hooks/trailing_whitespace_fixer.py index 1b54fbd..4fe7975 100644 --- a/pre_commit_hooks/trailing_whitespace_fixer.py +++ b/pre_commit_hooks/trailing_whitespace_fixer.py
@@ -3,9 +3,11 @@ import argparse import os import sys +from typing import Optional +from typing import Sequence -def _fix_file(filename, is_markdown): +def _fix_file(filename, is_markdown): # type: (str, bool) -> bool with open(filename, mode='rb') as file_processed: lines = file_processed.readlines() newlines = [_process_line(line, is_markdown) for line in lines] @@ -18,7 +20,7 @@ return False -def _process_line(line, is_markdown): +def _process_line(line, is_markdown): # type: (bytes, bool) -> bytes if line[-2:] == b'\r\n': eol = b'\r\n' elif line[-1:] == b'\n': @@ -31,7 +33,7 @@ return line.rstrip() + eol -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument( '--no-markdown-linebreak-ext',
diff --git a/pre_commit_hooks/util.py b/pre_commit_hooks/util.py index 269b553..5d1d11b 100644 --- a/pre_commit_hooks/util.py +++ b/pre_commit_hooks/util.py
@@ -3,23 +3,25 @@ from __future__ import unicode_literals import subprocess +from typing import Any +from typing import Set class CalledProcessError(RuntimeError): pass -def added_files(): +def added_files(): # type: () -> Set[str] return set(cmd_output( 'git', 'diff', '--staged', '--name-only', '--diff-filter=A', ).splitlines()) -def cmd_output(*cmd, **kwargs): +def cmd_output(*cmd, **kwargs): # type: (*str, **Any) -> str retcode = kwargs.pop('retcode', 0) - popen_kwargs = {'stdout': subprocess.PIPE, 'stderr': subprocess.PIPE} - popen_kwargs.update(kwargs) - proc = subprocess.Popen(cmd, **popen_kwargs) + kwargs.setdefault('stdout', subprocess.PIPE) + kwargs.setdefault('stderr', subprocess.PIPE) + proc = subprocess.Popen(cmd, **kwargs) stdout, stderr = proc.communicate() stdout = stdout.decode('UTF-8') if stderr is not None:
diff --git a/setup.py b/setup.py index 84892a7..756500b 100644 --- a/setup.py +++ b/setup.py
@@ -28,35 +28,36 @@ 'ruamel.yaml>=0.15', 'six', ], + extras_require={':python_version<"3.5"': ['typing']}, entry_points={ 'console_scripts': [ 'autopep8-wrapper = pre_commit_hooks.autopep8_wrapper:main', 'check-added-large-files = pre_commit_hooks.check_added_large_files:main', - 'check-ast = pre_commit_hooks.check_ast:check_ast', + 'check-ast = pre_commit_hooks.check_ast:main', 'check-builtin-literals = pre_commit_hooks.check_builtin_literals:main', 'check-byte-order-marker = pre_commit_hooks.check_byte_order_marker:main', 'check-case-conflict = pre_commit_hooks.check_case_conflict:main', 'check-docstring-first = pre_commit_hooks.check_docstring_first:main', 'check-executables-have-shebangs = pre_commit_hooks.check_executables_have_shebangs:main', - 'check-json = pre_commit_hooks.check_json:check_json', - 'check-merge-conflict = pre_commit_hooks.check_merge_conflict:detect_merge_conflict', - 'check-symlinks = pre_commit_hooks.check_symlinks:check_symlinks', + 'check-json = pre_commit_hooks.check_json:main', + 'check-merge-conflict = pre_commit_hooks.check_merge_conflict:main', + 'check-symlinks = pre_commit_hooks.check_symlinks:main', 'check-vcs-permalinks = pre_commit_hooks.check_vcs_permalinks:main', - 'check-xml = pre_commit_hooks.check_xml:check_xml', - 'check-yaml = pre_commit_hooks.check_yaml:check_yaml', + 'check-xml = pre_commit_hooks.check_xml:main', + 'check-yaml = pre_commit_hooks.check_yaml:main', 'debug-statement-hook = pre_commit_hooks.debug_statement_hook:main', 'detect-aws-credentials = pre_commit_hooks.detect_aws_credentials:main', - 'detect-private-key = pre_commit_hooks.detect_private_key:detect_private_key', + 'detect-private-key = pre_commit_hooks.detect_private_key:main', 'double-quote-string-fixer = pre_commit_hooks.string_fixer:main', - 'end-of-file-fixer = pre_commit_hooks.end_of_file_fixer:end_of_file_fixer', + 'end-of-file-fixer = pre_commit_hooks.end_of_file_fixer:main', 'file-contents-sorter = pre_commit_hooks.file_contents_sorter:main', 'fix-encoding-pragma = pre_commit_hooks.fix_encoding_pragma:main', 'forbid-new-submodules = pre_commit_hooks.forbid_new_submodules:main', 'mixed-line-ending = pre_commit_hooks.mixed_line_ending:main', - 'name-tests-test = pre_commit_hooks.tests_should_end_in_test:validate_files', + 'name-tests-test = pre_commit_hooks.tests_should_end_in_test:main', 'no-commit-to-branch = pre_commit_hooks.no_commit_to_branch:main', - 'pretty-format-json = pre_commit_hooks.pretty_format_json:pretty_format_json', - 'requirements-txt-fixer = pre_commit_hooks.requirements_txt_fixer:fix_requirements_txt', + 'pretty-format-json = pre_commit_hooks.pretty_format_json:main', + 'requirements-txt-fixer = pre_commit_hooks.requirements_txt_fixer:main', 'sort-simple-yaml = pre_commit_hooks.sort_simple_yaml:main', 'trailing-whitespace-fixer = pre_commit_hooks.trailing_whitespace_fixer:main', ],
diff --git a/testing/resources/bad_json_latin1.nonjson b/testing/resources/bad_json_latin1.nonjson old mode 100755 new mode 100644
diff --git a/testing/resources/builtin_constructors.py b/testing/resources/builtin_constructors.py deleted file mode 100644 index 174a9e8..0000000 --- a/testing/resources/builtin_constructors.py +++ /dev/null
@@ -1,17 +0,0 @@ -from six.moves import builtins - -c1 = complex() -d1 = dict() -f1 = float() -i1 = int() -l1 = list() -s1 = str() -t1 = tuple() - -c2 = builtins.complex() -d2 = builtins.dict() -f2 = builtins.float() -i2 = builtins.int() -l2 = builtins.list() -s2 = builtins.str() -t2 = builtins.tuple()
diff --git a/testing/resources/builtin_literals.py b/testing/resources/builtin_literals.py deleted file mode 100644 index 8513b70..0000000 --- a/testing/resources/builtin_literals.py +++ /dev/null
@@ -1,7 +0,0 @@ -c1 = 0j -d1 = {} -f1 = 0.0 -i1 = 0 -l1 = [] -s1 = '' -t1 = ()
diff --git a/tests/check_ast_test.py b/tests/check_ast_test.py index 64916ba..c16f5fc 100644 --- a/tests/check_ast_test.py +++ b/tests/check_ast_test.py
@@ -1,15 +1,15 @@ from __future__ import absolute_import from __future__ import unicode_literals -from pre_commit_hooks.check_ast import check_ast +from pre_commit_hooks.check_ast import main from testing.util import get_resource_path def test_failing_file(): - ret = check_ast([get_resource_path('cannot_parse_ast.notpy')]) + ret = main([get_resource_path('cannot_parse_ast.notpy')]) assert ret == 1 def test_passing_file(): - ret = check_ast([__file__]) + ret = main([__file__]) assert ret == 0
diff --git a/tests/check_builtin_literals_test.py b/tests/check_builtin_literals_test.py index 86b79e3..d4ac30f 100644 --- a/tests/check_builtin_literals_test.py +++ b/tests/check_builtin_literals_test.py
@@ -5,7 +5,35 @@ from pre_commit_hooks.check_builtin_literals import BuiltinTypeCall from pre_commit_hooks.check_builtin_literals import BuiltinTypeVisitor from pre_commit_hooks.check_builtin_literals import main -from testing.util import get_resource_path + +BUILTIN_CONSTRUCTORS = '''\ +from six.moves import builtins + +c1 = complex() +d1 = dict() +f1 = float() +i1 = int() +l1 = list() +s1 = str() +t1 = tuple() + +c2 = builtins.complex() +d2 = builtins.dict() +f2 = builtins.float() +i2 = builtins.int() +l2 = builtins.list() +s2 = builtins.str() +t2 = builtins.tuple() +''' +BUILTIN_LITERALS = '''\ +c1 = 0j +d1 = {} +f1 = 0.0 +i1 = 0 +l1 = [] +s1 = '' +t1 = () +''' @pytest.fixture @@ -94,24 +122,26 @@ def test_ignore_constructors(): visitor = BuiltinTypeVisitor(ignore=('complex', 'dict', 'float', 'int', 'list', 'str', 'tuple')) - with open(get_resource_path('builtin_constructors.py'), 'rb') as f: - visitor.visit(ast.parse(f.read(), 'builtin_constructors.py')) + visitor.visit(ast.parse(BUILTIN_CONSTRUCTORS)) assert visitor.builtin_type_calls == [] -def test_failing_file(): - rc = main([get_resource_path('builtin_constructors.py')]) +def test_failing_file(tmpdir): + f = tmpdir.join('f.py') + f.write(BUILTIN_CONSTRUCTORS) + rc = main([f.strpath]) assert rc == 1 -def test_passing_file(): - rc = main([get_resource_path('builtin_literals.py')]) +def test_passing_file(tmpdir): + f = tmpdir.join('f.py') + f.write(BUILTIN_LITERALS) + rc = main([f.strpath]) assert rc == 0 -def test_failing_file_ignore_all(): - rc = main([ - '--ignore=complex,dict,float,int,list,str,tuple', - get_resource_path('builtin_constructors.py'), - ]) +def test_failing_file_ignore_all(tmpdir): + f = tmpdir.join('f.py') + f.write(BUILTIN_CONSTRUCTORS) + rc = main(['--ignore=complex,dict,float,int,list,str,tuple', f.strpath]) assert rc == 0
diff --git a/tests/check_json_test.py b/tests/check_json_test.py index 6ba26c1..6654ed1 100644 --- a/tests/check_json_test.py +++ b/tests/check_json_test.py
@@ -1,6 +1,6 @@ import pytest -from pre_commit_hooks.check_json import check_json +from pre_commit_hooks.check_json import main from testing.util import get_resource_path @@ -11,8 +11,8 @@ ('ok_json.json', 0), ), ) -def test_check_json(capsys, filename, expected_retval): - ret = check_json([get_resource_path(filename)]) +def test_main(capsys, filename, expected_retval): + ret = main([get_resource_path(filename)]) assert ret == expected_retval if expected_retval == 1: stdout, _ = capsys.readouterr()
diff --git a/tests/check_merge_conflict_test.py b/tests/check_merge_conflict_test.py index b04c70e..50e389c 100644 --- a/tests/check_merge_conflict_test.py +++ b/tests/check_merge_conflict_test.py
@@ -6,7 +6,7 @@ import pytest -from pre_commit_hooks.check_merge_conflict import detect_merge_conflict +from pre_commit_hooks.check_merge_conflict import main from pre_commit_hooks.util import cmd_output from testing.util import get_resource_path @@ -102,7 +102,7 @@ @pytest.mark.usefixtures('f1_is_a_conflict_file') def test_merge_conflicts_git(): - assert detect_merge_conflict(['f1']) == 1 + assert main(['f1']) == 1 @pytest.mark.parametrize( @@ -110,7 +110,7 @@ ) def test_merge_conflicts_failing(contents, repository_pending_merge): repository_pending_merge.join('f2').write_binary(contents) - assert detect_merge_conflict(['f2']) == 1 + assert main(['f2']) == 1 @pytest.mark.parametrize( @@ -118,22 +118,22 @@ ) def test_merge_conflicts_ok(contents, f1_is_a_conflict_file): f1_is_a_conflict_file.join('f1').write_binary(contents) - assert detect_merge_conflict(['f1']) == 0 + assert main(['f1']) == 0 @pytest.mark.usefixtures('f1_is_a_conflict_file') def test_ignores_binary_files(): shutil.copy(get_resource_path('img1.jpg'), 'f1') - assert detect_merge_conflict(['f1']) == 0 + assert main(['f1']) == 0 def test_does_not_care_when_not_in_a_merge(tmpdir): f = tmpdir.join('README.md') f.write_binary(b'problem\n=======\n') - assert detect_merge_conflict([str(f.realpath())]) == 0 + assert main([str(f.realpath())]) == 0 def test_care_when_assumed_merge(tmpdir): f = tmpdir.join('README.md') f.write_binary(b'problem\n=======\n') - assert detect_merge_conflict([str(f.realpath()), '--assume-in-merge']) == 1 + assert main([str(f.realpath()), '--assume-in-merge']) == 1
diff --git a/tests/check_symlinks_test.py b/tests/check_symlinks_test.py index 0414df5..ecbc7ae 100644 --- a/tests/check_symlinks_test.py +++ b/tests/check_symlinks_test.py
@@ -2,7 +2,7 @@ import pytest -from pre_commit_hooks.check_symlinks import check_symlinks +from pre_commit_hooks.check_symlinks import main xfail_symlink = pytest.mark.xfail(os.name == 'nt', reason='No symlink support') @@ -12,12 +12,12 @@ @pytest.mark.parametrize( ('dest', 'expected'), (('exists', 0), ('does-not-exist', 1)), ) -def test_check_symlinks(tmpdir, dest, expected): # pragma: no cover (symlinks) +def test_main(tmpdir, dest, expected): # pragma: no cover (symlinks) tmpdir.join('exists').ensure() symlink = tmpdir.join('symlink') symlink.mksymlinkto(tmpdir.join(dest)) - assert check_symlinks((symlink.strpath,)) == expected + assert main((symlink.strpath,)) == expected -def test_check_symlinks_normal_file(tmpdir): - assert check_symlinks((tmpdir.join('f').ensure().strpath,)) == 0 +def test_main_normal_file(tmpdir): + assert main((tmpdir.join('f').ensure().strpath,)) == 0
diff --git a/tests/check_xml_test.py b/tests/check_xml_test.py index 84e365d..357bad6 100644 --- a/tests/check_xml_test.py +++ b/tests/check_xml_test.py
@@ -1,6 +1,6 @@ import pytest -from pre_commit_hooks.check_xml import check_xml +from pre_commit_hooks.check_xml import main from testing.util import get_resource_path @@ -10,6 +10,6 @@ ('ok_xml.xml', 0), ), ) -def test_check_xml(filename, expected_retval): - ret = check_xml([get_resource_path(filename)]) +def test_main(filename, expected_retval): + ret = main([get_resource_path(filename)]) assert ret == expected_retval
diff --git a/tests/check_yaml_test.py b/tests/check_yaml_test.py index aa357f1..d267150 100644 --- a/tests/check_yaml_test.py +++ b/tests/check_yaml_test.py
@@ -3,7 +3,7 @@ import pytest -from pre_commit_hooks.check_yaml import check_yaml +from pre_commit_hooks.check_yaml import main from testing.util import get_resource_path @@ -13,29 +13,29 @@ ('ok_yaml.yaml', 0), ), ) -def test_check_yaml(filename, expected_retval): - ret = check_yaml([get_resource_path(filename)]) +def test_main(filename, expected_retval): + ret = main([get_resource_path(filename)]) assert ret == expected_retval -def test_check_yaml_allow_multiple_documents(tmpdir): +def test_main_allow_multiple_documents(tmpdir): f = tmpdir.join('test.yaml') f.write('---\nfoo\n---\nbar\n') # should fail without the setting - assert check_yaml((f.strpath,)) + assert main((f.strpath,)) # should pass when we allow multiple documents - assert not check_yaml(('--allow-multiple-documents', f.strpath)) + assert not main(('--allow-multiple-documents', f.strpath)) def test_fails_even_with_allow_multiple_documents(tmpdir): f = tmpdir.join('test.yaml') f.write('[') - assert check_yaml(('--allow-multiple-documents', f.strpath)) + assert main(('--allow-multiple-documents', f.strpath)) -def test_check_yaml_unsafe(tmpdir): +def test_main_unsafe(tmpdir): f = tmpdir.join('test.yaml') f.write( 'some_foo: !vault |\n' @@ -43,12 +43,12 @@ ' deadbeefdeadbeefdeadbeef\n', ) # should fail "safe" check - assert check_yaml((f.strpath,)) + assert main((f.strpath,)) # should pass when we allow unsafe documents - assert not check_yaml(('--unsafe', f.strpath)) + assert not main(('--unsafe', f.strpath)) -def test_check_yaml_unsafe_still_fails_on_syntax_errors(tmpdir): +def test_main_unsafe_still_fails_on_syntax_errors(tmpdir): f = tmpdir.join('test.yaml') f.write('[') - assert check_yaml(('--unsafe', f.strpath)) + assert main(('--unsafe', f.strpath))
diff --git a/tests/detect_private_key_test.py b/tests/detect_private_key_test.py index fdd63a2..9266f2b 100644 --- a/tests/detect_private_key_test.py +++ b/tests/detect_private_key_test.py
@@ -1,6 +1,6 @@ import pytest -from pre_commit_hooks.detect_private_key import detect_private_key +from pre_commit_hooks.detect_private_key import main # Input, expected return value TESTS = ( @@ -18,7 +18,7 @@ @pytest.mark.parametrize(('input_s', 'expected_retval'), TESTS) -def test_detect_private_key(input_s, expected_retval, tmpdir): +def test_main(input_s, expected_retval, tmpdir): path = tmpdir.join('file.txt') path.write_binary(input_s) - assert detect_private_key([path.strpath]) == expected_retval + assert main([path.strpath]) == expected_retval
diff --git a/tests/end_of_file_fixer_test.py b/tests/end_of_file_fixer_test.py index f8710af..7f644e7 100644 --- a/tests/end_of_file_fixer_test.py +++ b/tests/end_of_file_fixer_test.py
@@ -2,8 +2,8 @@ import pytest -from pre_commit_hooks.end_of_file_fixer import end_of_file_fixer from pre_commit_hooks.end_of_file_fixer import fix_file +from pre_commit_hooks.end_of_file_fixer import main # Input, expected return value, expected output @@ -35,7 +35,7 @@ path = tmpdir.join('file.txt') path.write_binary(input_s) - ret = end_of_file_fixer([path.strpath]) + ret = main([path.strpath]) file_output = path.read_binary() assert file_output == output
diff --git a/tests/no_commit_to_branch_test.py b/tests/no_commit_to_branch_test.py index c275bf7..e978ba2 100644 --- a/tests/no_commit_to_branch_test.py +++ b/tests/no_commit_to_branch_test.py
@@ -11,24 +11,24 @@ def test_other_branch(temp_git_dir): with temp_git_dir.as_cwd(): cmd_output('git', 'checkout', '-b', 'anotherbranch') - assert is_on_branch(('master',)) is False + assert is_on_branch({'master'}) is False def test_multi_branch(temp_git_dir): with temp_git_dir.as_cwd(): cmd_output('git', 'checkout', '-b', 'another/branch') - assert is_on_branch(('master',)) is False + assert is_on_branch({'master'}) is False def test_multi_branch_fail(temp_git_dir): with temp_git_dir.as_cwd(): cmd_output('git', 'checkout', '-b', 'another/branch') - assert is_on_branch(('another/branch',)) is True + assert is_on_branch({'another/branch'}) is True def test_master_branch(temp_git_dir): with temp_git_dir.as_cwd(): - assert is_on_branch(('master',)) is True + assert is_on_branch({'master'}) is True def test_main_branch_call(temp_git_dir):
diff --git a/tests/pretty_format_json_test.py b/tests/pretty_format_json_test.py index 7ce7e16..8d82d74 100644 --- a/tests/pretty_format_json_test.py +++ b/tests/pretty_format_json_test.py
@@ -3,8 +3,8 @@ import pytest from six import PY2 +from pre_commit_hooks.pretty_format_json import main from pre_commit_hooks.pretty_format_json import parse_num_to_int -from pre_commit_hooks.pretty_format_json import pretty_format_json from testing.util import get_resource_path @@ -23,8 +23,8 @@ ('pretty_formatted_json.json', 0), ), ) -def test_pretty_format_json(filename, expected_retval): - ret = pretty_format_json([get_resource_path(filename)]) +def test_main(filename, expected_retval): + ret = main([get_resource_path(filename)]) assert ret == expected_retval @@ -36,8 +36,8 @@ ('pretty_formatted_json.json', 0), ), ) -def test_unsorted_pretty_format_json(filename, expected_retval): - ret = pretty_format_json(['--no-sort-keys', get_resource_path(filename)]) +def test_unsorted_main(filename, expected_retval): + ret = main(['--no-sort-keys', get_resource_path(filename)]) assert ret == expected_retval @@ -51,17 +51,17 @@ ('tab_pretty_formatted_json.json', 0), ), ) -def test_tab_pretty_format_json(filename, expected_retval): # pragma: no cover - ret = pretty_format_json(['--indent', '\t', get_resource_path(filename)]) +def test_tab_main(filename, expected_retval): # pragma: no cover + ret = main(['--indent', '\t', get_resource_path(filename)]) assert ret == expected_retval -def test_non_ascii_pretty_format_json(): - ret = pretty_format_json(['--no-ensure-ascii', get_resource_path('non_ascii_pretty_formatted_json.json')]) +def test_non_ascii_main(): + ret = main(['--no-ensure-ascii', get_resource_path('non_ascii_pretty_formatted_json.json')]) assert ret == 0 -def test_autofix_pretty_format_json(tmpdir): +def test_autofix_main(tmpdir): srcfile = tmpdir.join('to_be_json_formatted.json') shutil.copyfile( get_resource_path('not_pretty_formatted_json.json'), @@ -69,30 +69,30 @@ ) # now launch the autofix on that file - ret = pretty_format_json(['--autofix', srcfile.strpath]) + ret = main(['--autofix', srcfile.strpath]) # it should have formatted it assert ret == 1 # file was formatted (shouldn't trigger linter again) - ret = pretty_format_json([srcfile.strpath]) + ret = main([srcfile.strpath]) assert ret == 0 def test_orderfile_get_pretty_format(): - ret = pretty_format_json(['--top-keys=alist', get_resource_path('pretty_formatted_json.json')]) + ret = main(['--top-keys=alist', get_resource_path('pretty_formatted_json.json')]) assert ret == 0 def test_not_orderfile_get_pretty_format(): - ret = pretty_format_json(['--top-keys=blah', get_resource_path('pretty_formatted_json.json')]) + ret = main(['--top-keys=blah', get_resource_path('pretty_formatted_json.json')]) assert ret == 1 def test_top_sorted_get_pretty_format(): - ret = pretty_format_json(['--top-keys=01-alist,alist', get_resource_path('top_sorted_json.json')]) + ret = main(['--top-keys=01-alist,alist', get_resource_path('top_sorted_json.json')]) assert ret == 0 -def test_badfile_pretty_format_json(): - ret = pretty_format_json([get_resource_path('ok_yaml.yaml')]) +def test_badfile_main(): + ret = main([get_resource_path('ok_yaml.yaml')]) assert ret == 1
diff --git a/tests/requirements_txt_fixer_test.py b/tests/requirements_txt_fixer_test.py index 437cebd..b3a7942 100644 --- a/tests/requirements_txt_fixer_test.py +++ b/tests/requirements_txt_fixer_test.py
@@ -1,7 +1,7 @@ import pytest from pre_commit_hooks.requirements_txt_fixer import FAIL -from pre_commit_hooks.requirements_txt_fixer import fix_requirements_txt +from pre_commit_hooks.requirements_txt_fixer import main from pre_commit_hooks.requirements_txt_fixer import PASS from pre_commit_hooks.requirements_txt_fixer import Requirement @@ -36,7 +36,7 @@ path = tmpdir.join('file.txt') path.write_binary(input_s) - output_retval = fix_requirements_txt([path.strpath]) + output_retval = main([path.strpath]) assert path.read_binary() == output assert output_retval == expected_retval @@ -44,7 +44,7 @@ def test_requirement_object(): top_of_file = Requirement() - top_of_file.comments.append('#foo') + top_of_file.comments.append(b'#foo') top_of_file.value = b'\n' requirement_foo = Requirement()
diff --git a/tests/sort_simple_yaml_test.py b/tests/sort_simple_yaml_test.py index 176d12f..72f5bec 100644 --- a/tests/sort_simple_yaml_test.py +++ b/tests/sort_simple_yaml_test.py
@@ -110,9 +110,9 @@ lines = ['# some comment', '"a": 42', 'b: 17', '', 'c: 19'] assert first_key(lines) == 'a": 42' - # no lines + # no lines (not a real situation) lines = [] - assert first_key(lines) is None + assert first_key(lines) == '' @pytest.mark.parametrize('bad_lines,good_lines,_', TEST_SORTS)
diff --git a/tests/tests_should_end_in_test_test.py b/tests/tests_should_end_in_test_test.py index dc686a5..4eb98e7 100644 --- a/tests/tests_should_end_in_test_test.py +++ b/tests/tests_should_end_in_test_test.py
@@ -1,36 +1,36 @@ -from pre_commit_hooks.tests_should_end_in_test import validate_files +from pre_commit_hooks.tests_should_end_in_test import main -def test_validate_files_all_pass(): - ret = validate_files(['foo_test.py', 'bar_test.py']) +def test_main_all_pass(): + ret = main(['foo_test.py', 'bar_test.py']) assert ret == 0 -def test_validate_files_one_fails(): - ret = validate_files(['not_test_ending.py', 'foo_test.py']) +def test_main_one_fails(): + ret = main(['not_test_ending.py', 'foo_test.py']) assert ret == 1 -def test_validate_files_django_all_pass(): - ret = validate_files(['--django', 'tests.py', 'test_foo.py', 'test_bar.py', 'tests/test_baz.py']) +def test_main_django_all_pass(): + ret = main(['--django', 'tests.py', 'test_foo.py', 'test_bar.py', 'tests/test_baz.py']) assert ret == 0 -def test_validate_files_django_one_fails(): - ret = validate_files(['--django', 'not_test_ending.py', 'test_foo.py']) +def test_main_django_one_fails(): + ret = main(['--django', 'not_test_ending.py', 'test_foo.py']) assert ret == 1 def test_validate_nested_files_django_one_fails(): - ret = validate_files(['--django', 'tests/not_test_ending.py', 'test_foo.py']) + ret = main(['--django', 'tests/not_test_ending.py', 'test_foo.py']) assert ret == 1 -def test_validate_files_not_django_fails(): - ret = validate_files(['foo_test.py', 'bar_test.py', 'test_baz.py']) +def test_main_not_django_fails(): + ret = main(['foo_test.py', 'bar_test.py', 'test_baz.py']) assert ret == 1 -def test_validate_files_django_fails(): - ret = validate_files(['--django', 'foo_test.py', 'test_bar.py', 'test_baz.py']) +def test_main_django_fails(): + ret = main(['--django', 'foo_test.py', 'test_bar.py', 'test_baz.py']) assert ret == 1
diff --git a/tox.ini b/tox.ini index c131e6f..d1e6a79 100644 --- a/tox.ini +++ b/tox.ini
@@ -1,6 +1,6 @@ [tox] # These should match the travis env list -envlist = py27,py36,py37,pypy +envlist = py27,py36,py37,pypy3 [testenv] deps = -rrequirements-dev.txt