Apply typing to all of pre-commit-hooks
diff --git a/pre_commit_hooks/autopep8_wrapper.py b/pre_commit_hooks/autopep8_wrapper.py index 9951924..8b69a04 100644 --- a/pre_commit_hooks/autopep8_wrapper.py +++ b/pre_commit_hooks/autopep8_wrapper.py
@@ -3,7 +3,7 @@ from __future__ import unicode_literals -def main(argv=None): +def main(): # type: () -> int raise SystemExit( 'autopep8-wrapper is deprecated. Instead use autopep8 directly via ' 'https://github.com/pre-commit/mirrors-autopep8',
diff --git a/pre_commit_hooks/check_added_large_files.py b/pre_commit_hooks/check_added_large_files.py index 2d06706..be39498 100644 --- a/pre_commit_hooks/check_added_large_files.py +++ b/pre_commit_hooks/check_added_large_files.py
@@ -7,13 +7,17 @@ import json import math import os +from typing import Iterable +from typing import Optional +from typing import Sequence +from typing import Set from pre_commit_hooks.util import added_files from pre_commit_hooks.util import CalledProcessError from pre_commit_hooks.util import cmd_output -def lfs_files(): +def lfs_files(): # type: () -> Set[str] try: # Introduced in git-lfs 2.2.0, first working in 2.2.1 lfs_ret = cmd_output('git', 'lfs', 'status', '--json') @@ -24,6 +28,7 @@ def find_large_added_files(filenames, maxkb): + # type: (Iterable[str], int) -> int # Find all added files that are also in the list of files pre-commit tells # us about filenames = (added_files() & set(filenames)) - lfs_files() @@ -38,7 +43,7 @@ return retv -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument( 'filenames', nargs='*',
diff --git a/pre_commit_hooks/check_ast.py b/pre_commit_hooks/check_ast.py index ded65e4..0df3540 100644 --- a/pre_commit_hooks/check_ast.py +++ b/pre_commit_hooks/check_ast.py
@@ -7,9 +7,11 @@ import platform import sys import traceback +from typing import Optional +from typing import Sequence -def check_ast(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*') args = parser.parse_args(argv) @@ -34,4 +36,4 @@ if __name__ == '__main__': - exit(check_ast()) + exit(main())
diff --git a/pre_commit_hooks/check_builtin_literals.py b/pre_commit_hooks/check_builtin_literals.py index 4a4b9ce..874c68c 100644 --- a/pre_commit_hooks/check_builtin_literals.py +++ b/pre_commit_hooks/check_builtin_literals.py
@@ -4,6 +4,10 @@ import ast import collections import sys +from typing import List +from typing import Optional +from typing import Sequence +from typing import Set BUILTIN_TYPES = { @@ -22,14 +26,17 @@ class BuiltinTypeVisitor(ast.NodeVisitor): def __init__(self, ignore=None, allow_dict_kwargs=True): - self.builtin_type_calls = [] + # type: (Optional[Sequence[str]], bool) -> None + self.builtin_type_calls = [] # type: List[BuiltinTypeCall] self.ignore = set(ignore) if ignore else set() self.allow_dict_kwargs = allow_dict_kwargs - def _check_dict_call(self, node): + def _check_dict_call(self, node): # type: (ast.Call) -> bool + return self.allow_dict_kwargs and (getattr(node, 'kwargs', None) or getattr(node, 'keywords', None)) - def visit_Call(self, node): + def visit_Call(self, node): # type: (ast.Call) -> None + if not isinstance(node.func, ast.Name): # Ignore functions that are object attributes (`foo.bar()`). # Assume that if the user calls `builtins.list()`, they know what @@ -47,6 +54,7 @@ def check_file_for_builtin_type_constructors(filename, ignore=None, allow_dict_kwargs=True): + # type: (str, Optional[Sequence[str]], bool) -> List[BuiltinTypeCall] with open(filename, 'rb') as f: tree = ast.parse(f.read(), filename=filename) visitor = BuiltinTypeVisitor(ignore=ignore, allow_dict_kwargs=allow_dict_kwargs) @@ -54,24 +62,22 @@ return visitor.builtin_type_calls -def parse_args(argv): - def parse_ignore(value): - return set(value.split(',')) +def parse_ignore(value): # type: (str) -> Set[str] + return set(value.split(',')) + +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*') parser.add_argument('--ignore', type=parse_ignore, default=set()) - allow_dict_kwargs = parser.add_mutually_exclusive_group(required=False) - allow_dict_kwargs.add_argument('--allow-dict-kwargs', action='store_true') - allow_dict_kwargs.add_argument('--no-allow-dict-kwargs', dest='allow_dict_kwargs', 
action='store_false') - allow_dict_kwargs.set_defaults(allow_dict_kwargs=True) + mutex = parser.add_mutually_exclusive_group(required=False) + mutex.add_argument('--allow-dict-kwargs', action='store_true') + mutex.add_argument('--no-allow-dict-kwargs', dest='allow_dict_kwargs', action='store_false') + mutex.set_defaults(allow_dict_kwargs=True) - return parser.parse_args(argv) + args = parser.parse_args(argv) - -def main(argv=None): - args = parse_args(argv) rc = 0 for filename in args.filenames: calls = check_file_for_builtin_type_constructors(
diff --git a/pre_commit_hooks/check_byte_order_marker.py b/pre_commit_hooks/check_byte_order_marker.py index 1541b30..10667c3 100644 --- a/pre_commit_hooks/check_byte_order_marker.py +++ b/pre_commit_hooks/check_byte_order_marker.py
@@ -3,9 +3,11 @@ from __future__ import unicode_literals import argparse +from typing import Optional +from typing import Sequence -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*', help='Filenames to check') args = parser.parse_args(argv)
diff --git a/pre_commit_hooks/check_case_conflict.py b/pre_commit_hooks/check_case_conflict.py index 0f78296..e343d61 100644 --- a/pre_commit_hooks/check_case_conflict.py +++ b/pre_commit_hooks/check_case_conflict.py
@@ -3,16 +3,20 @@ from __future__ import unicode_literals import argparse +from typing import Iterable +from typing import Optional +from typing import Sequence +from typing import Set from pre_commit_hooks.util import added_files from pre_commit_hooks.util import cmd_output -def lower_set(iterable): +def lower_set(iterable): # type: (Iterable[str]) -> Set[str] return {x.lower() for x in iterable} -def find_conflicting_filenames(filenames): +def find_conflicting_filenames(filenames): # type: (Sequence[str]) -> int repo_files = set(cmd_output('git', 'ls-files').splitlines()) relevant_files = set(filenames) | added_files() repo_files -= relevant_files @@ -41,7 +45,7 @@ return retv -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument( 'filenames', nargs='*',
diff --git a/pre_commit_hooks/check_docstring_first.py b/pre_commit_hooks/check_docstring_first.py index 9988378..f4639f1 100644 --- a/pre_commit_hooks/check_docstring_first.py +++ b/pre_commit_hooks/check_docstring_first.py
@@ -5,6 +5,8 @@ import argparse import io import tokenize +from typing import Optional +from typing import Sequence NON_CODE_TOKENS = frozenset(( @@ -13,6 +15,7 @@ def check_docstring_first(src, filename='<unknown>'): + # type: (str, str) -> int """Returns nonzero if the source has what looks like a docstring that is not at the beginning of the source. @@ -50,7 +53,7 @@ return 0 -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*') args = parser.parse_args(argv)
diff --git a/pre_commit_hooks/check_executables_have_shebangs.py b/pre_commit_hooks/check_executables_have_shebangs.py index 89ac6e5..c936a5d 100644 --- a/pre_commit_hooks/check_executables_have_shebangs.py +++ b/pre_commit_hooks/check_executables_have_shebangs.py
@@ -6,9 +6,11 @@ import argparse import pipes import sys +from typing import Optional +from typing import Sequence -def check_has_shebang(path): +def check_has_shebang(path): # type: (str) -> int with open(path, 'rb') as f: first_bytes = f.read(2) @@ -27,7 +29,7 @@ return 0 -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser(description=__doc__) parser.add_argument('filenames', nargs='*') args = parser.parse_args(argv) @@ -38,3 +40,7 @@ retv |= check_has_shebang(filename) return retv + + +if __name__ == '__main__': + exit(main())
diff --git a/pre_commit_hooks/check_json.py b/pre_commit_hooks/check_json.py index b403f4b..b939350 100644 --- a/pre_commit_hooks/check_json.py +++ b/pre_commit_hooks/check_json.py
@@ -4,9 +4,11 @@ import io import json import sys +from typing import Optional +from typing import Sequence -def check_json(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*', help='JSON filenames to check.') args = parser.parse_args(argv) @@ -22,4 +24,4 @@ if __name__ == '__main__': - sys.exit(check_json()) + sys.exit(main())
diff --git a/pre_commit_hooks/check_merge_conflict.py b/pre_commit_hooks/check_merge_conflict.py index 6db5efe..74e4ae1 100644 --- a/pre_commit_hooks/check_merge_conflict.py +++ b/pre_commit_hooks/check_merge_conflict.py
@@ -2,6 +2,9 @@ import argparse import os.path +from typing import Optional +from typing import Sequence + CONFLICT_PATTERNS = [ b'<<<<<<< ', @@ -12,7 +15,7 @@ WARNING_MSG = 'Merge conflict string "{0}" found in {1}:{2}' -def is_in_merge(): +def is_in_merge(): # type: () -> bool return ( os.path.exists(os.path.join('.git', 'MERGE_MSG')) and ( @@ -23,7 +26,7 @@ ) -def detect_merge_conflict(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*') parser.add_argument('--assume-in-merge', action='store_true') @@ -47,4 +50,4 @@ if __name__ == '__main__': - exit(detect_merge_conflict()) + exit(main())
diff --git a/pre_commit_hooks/check_symlinks.py b/pre_commit_hooks/check_symlinks.py index 010c871..736bf99 100644 --- a/pre_commit_hooks/check_symlinks.py +++ b/pre_commit_hooks/check_symlinks.py
@@ -4,9 +4,11 @@ import argparse import os.path +from typing import Optional +from typing import Sequence -def check_symlinks(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser(description='Checks for broken symlinks.') parser.add_argument('filenames', nargs='*', help='Filenames to check') args = parser.parse_args(argv) @@ -25,4 +27,4 @@ if __name__ == '__main__': - exit(check_symlinks()) + exit(main())
diff --git a/pre_commit_hooks/check_vcs_permalinks.py b/pre_commit_hooks/check_vcs_permalinks.py index f0dcf5b..f6e2a7d 100644 --- a/pre_commit_hooks/check_vcs_permalinks.py +++ b/pre_commit_hooks/check_vcs_permalinks.py
@@ -5,6 +5,8 @@ import argparse import re import sys +from typing import Optional +from typing import Sequence GITHUB_NON_PERMALINK = re.compile( @@ -12,7 +14,7 @@ ) -def _check_filename(filename): +def _check_filename(filename): # type: (str) -> int retv = 0 with open(filename, 'rb') as f: for i, line in enumerate(f, 1): @@ -24,7 +26,7 @@ return retv -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*') args = parser.parse_args(argv)
diff --git a/pre_commit_hooks/check_xml.py b/pre_commit_hooks/check_xml.py index a4c11a5..66e10ba 100644 --- a/pre_commit_hooks/check_xml.py +++ b/pre_commit_hooks/check_xml.py
@@ -5,10 +5,12 @@ import argparse import io import sys -import xml.sax +import xml.sax.handler +from typing import Optional +from typing import Sequence -def check_xml(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*', help='XML filenames to check.') args = parser.parse_args(argv) @@ -17,7 +19,7 @@ for filename in args.filenames: try: with io.open(filename, 'rb') as xml_file: - xml.sax.parse(xml_file, xml.sax.ContentHandler()) + xml.sax.parse(xml_file, xml.sax.handler.ContentHandler()) except xml.sax.SAXException as exc: print('{}: Failed to xml parse ({})'.format(filename, exc)) retval = 1 @@ -25,4 +27,4 @@ if __name__ == '__main__': - sys.exit(check_xml()) + sys.exit(main())
diff --git a/pre_commit_hooks/check_yaml.py b/pre_commit_hooks/check_yaml.py index 208737f..b638684 100644 --- a/pre_commit_hooks/check_yaml.py +++ b/pre_commit_hooks/check_yaml.py
@@ -3,22 +3,26 @@ import argparse import collections import sys +from typing import Any +from typing import Generator +from typing import Optional +from typing import Sequence import ruamel.yaml yaml = ruamel.yaml.YAML(typ='safe') -def _exhaust(gen): +def _exhaust(gen): # type: (Generator[Any, None, None]) -> None for _ in gen: pass -def _parse_unsafe(*args, **kwargs): +def _parse_unsafe(*args, **kwargs): # type: (*Any, **Any) -> None _exhaust(yaml.parse(*args, **kwargs)) -def _load_all(*args, **kwargs): +def _load_all(*args, **kwargs): # type: (*Any, **Any) -> None _exhaust(yaml.load_all(*args, **kwargs)) @@ -31,7 +35,7 @@ } -def check_yaml(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument( '-m', '--multi', '--allow-multiple-documents', action='store_true', @@ -63,4 +67,4 @@ if __name__ == '__main__': - sys.exit(check_yaml()) + sys.exit(main())
diff --git a/pre_commit_hooks/debug_statement_hook.py b/pre_commit_hooks/debug_statement_hook.py index 5d32277..02dd3b2 100644 --- a/pre_commit_hooks/debug_statement_hook.py +++ b/pre_commit_hooks/debug_statement_hook.py
@@ -5,6 +5,9 @@ import ast import collections import traceback +from typing import List +from typing import Optional +from typing import Sequence DEBUG_STATEMENTS = {'pdb', 'ipdb', 'pudb', 'q', 'rdb'} @@ -12,21 +15,21 @@ class DebugStatementParser(ast.NodeVisitor): - def __init__(self): - self.breakpoints = [] + def __init__(self): # type: () -> None + self.breakpoints = [] # type: List[Debug] - def visit_Import(self, node): + def visit_Import(self, node): # type: (ast.Import) -> None for name in node.names: if name.name in DEBUG_STATEMENTS: st = Debug(node.lineno, node.col_offset, name.name, 'imported') self.breakpoints.append(st) - def visit_ImportFrom(self, node): + def visit_ImportFrom(self, node): # type: (ast.ImportFrom) -> None if node.module in DEBUG_STATEMENTS: st = Debug(node.lineno, node.col_offset, node.module, 'imported') self.breakpoints.append(st) - def visit_Call(self, node): + def visit_Call(self, node): # type: (ast.Call) -> None """python3.7+ breakpoint()""" if isinstance(node.func, ast.Name) and node.func.id == 'breakpoint': st = Debug(node.lineno, node.col_offset, node.func.id, 'called') @@ -34,7 +37,7 @@ self.generic_visit(node) -def check_file(filename): +def check_file(filename): # type: (str) -> int try: with open(filename, 'rb') as f: ast_obj = ast.parse(f.read(), filename=filename) @@ -58,7 +61,7 @@ return int(bool(visitor.breakpoints)) -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*', help='Filenames to run') args = parser.parse_args(argv)
diff --git a/pre_commit_hooks/detect_aws_credentials.py b/pre_commit_hooks/detect_aws_credentials.py index ecd9d40..3c87d11 100644 --- a/pre_commit_hooks/detect_aws_credentials.py +++ b/pre_commit_hooks/detect_aws_credentials.py
@@ -3,11 +3,16 @@ import argparse import os +from typing import Dict +from typing import List +from typing import Optional +from typing import Sequence +from typing import Set from six.moves import configparser -def get_aws_credential_files_from_env(): +def get_aws_credential_files_from_env(): # type: () -> Set[str] """Extract credential file paths from environment variables.""" files = set() for env_var in ( @@ -19,7 +24,7 @@ return files -def get_aws_secrets_from_env(): +def get_aws_secrets_from_env(): # type: () -> Set[str] """Extract AWS secrets from environment variables.""" keys = set() for env_var in ( @@ -30,7 +35,7 @@ return keys -def get_aws_secrets_from_file(credentials_file): +def get_aws_secrets_from_file(credentials_file): # type: (str) -> Set[str] """Extract AWS secrets from configuration files. Read an ini-style configuration file and return a set with all found AWS @@ -62,6 +67,7 @@ def check_file_for_aws_keys(filenames, keys): + # type: (Sequence[str], Set[str]) -> List[Dict[str, str]] """Check if files contain AWS secrets. Return a list of all files containing AWS secrets and keys found, with all @@ -82,7 +88,7 @@ return bad_files -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='+', help='Filenames to run') parser.add_argument( @@ -111,7 +117,7 @@ # of files to to gather AWS secrets from. credential_files |= get_aws_credential_files_from_env() - keys = set() + keys = set() # type: Set[str] for credential_file in credential_files: keys |= get_aws_secrets_from_file(credential_file)
diff --git a/pre_commit_hooks/detect_private_key.py b/pre_commit_hooks/detect_private_key.py index c8ee961..d31957d 100644 --- a/pre_commit_hooks/detect_private_key.py +++ b/pre_commit_hooks/detect_private_key.py
@@ -2,6 +2,8 @@ import argparse import sys +from typing import Optional +from typing import Sequence BLACKLIST = [ b'BEGIN RSA PRIVATE KEY', @@ -15,7 +17,7 @@ ] -def detect_private_key(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*', help='Filenames to check') args = parser.parse_args(argv) @@ -37,4 +39,4 @@ if __name__ == '__main__': - sys.exit(detect_private_key()) + sys.exit(main())
diff --git a/pre_commit_hooks/end_of_file_fixer.py b/pre_commit_hooks/end_of_file_fixer.py index 5ab1b7b..4e77c94 100644 --- a/pre_commit_hooks/end_of_file_fixer.py +++ b/pre_commit_hooks/end_of_file_fixer.py
@@ -4,9 +4,12 @@ import argparse import os import sys +from typing import IO +from typing import Optional +from typing import Sequence -def fix_file(file_obj): +def fix_file(file_obj): # type: (IO[bytes]) -> int # Test for newline at end of file # Empty files will throw IOError here try: @@ -49,7 +52,7 @@ return 0 -def end_of_file_fixer(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*', help='Filenames to fix') args = parser.parse_args(argv) @@ -68,4 +71,4 @@ if __name__ == '__main__': - sys.exit(end_of_file_fixer()) + sys.exit(main())
diff --git a/pre_commit_hooks/file_contents_sorter.py b/pre_commit_hooks/file_contents_sorter.py index fe7f7ee..6f13c98 100644 --- a/pre_commit_hooks/file_contents_sorter.py +++ b/pre_commit_hooks/file_contents_sorter.py
@@ -12,12 +12,15 @@ from __future__ import print_function import argparse +from typing import IO +from typing import Optional +from typing import Sequence PASS = 0 FAIL = 1 -def sort_file_contents(f): +def sort_file_contents(f): # type: (IO[bytes]) -> int before = list(f) after = sorted([line.strip(b'\n\r') for line in before if line.strip()]) @@ -33,7 +36,7 @@ return FAIL -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='+', help='Files to sort') args = parser.parse_args(argv)
diff --git a/pre_commit_hooks/fix_encoding_pragma.py b/pre_commit_hooks/fix_encoding_pragma.py index 3bf234e..b0b5c8e 100644 --- a/pre_commit_hooks/fix_encoding_pragma.py +++ b/pre_commit_hooks/fix_encoding_pragma.py
@@ -4,11 +4,15 @@ import argparse import collections +from typing import IO +from typing import Optional +from typing import Sequence +from typing import Union DEFAULT_PRAGMA = b'# -*- coding: utf-8 -*-\n' -def has_coding(line): +def has_coding(line): # type: (bytes) -> bool if not line.strip(): return False return ( @@ -33,15 +37,16 @@ __slots__ = () @property - def has_any_pragma(self): + def has_any_pragma(self): # type: () -> bool return self.pragma_status is not False - def is_expected_pragma(self, remove): + def is_expected_pragma(self, remove): # type: (bool) -> bool expected_pragma_status = not remove return self.pragma_status is expected_pragma_status def _get_expected_contents(first_line, second_line, rest, expected_pragma): + # type: (bytes, bytes, bytes, bytes) -> ExpectedContents if first_line.startswith(b'#!'): shebang = first_line potential_coding = second_line @@ -51,7 +56,7 @@ rest = second_line + rest if potential_coding == expected_pragma: - pragma_status = True + pragma_status = True # type: Optional[bool] elif has_coding(potential_coding): pragma_status = None else: @@ -64,6 +69,7 @@ def fix_encoding_pragma(f, remove=False, expected_pragma=DEFAULT_PRAGMA): + # type: (IO[bytes], bool, bytes) -> int expected = _get_expected_contents( f.readline(), f.readline(), f.read(), expected_pragma, ) @@ -93,17 +99,17 @@ return 1 -def _normalize_pragma(pragma): +def _normalize_pragma(pragma): # type: (Union[bytes, str]) -> bytes if not isinstance(pragma, bytes): pragma = pragma.encode('UTF-8') return pragma.rstrip() + b'\n' -def _to_disp(pragma): +def _to_disp(pragma): # type: (bytes) -> str return pragma.decode().rstrip() -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser('Fixes the encoding pragma of python files') parser.add_argument('filenames', nargs='*', help='Filenames to fix') parser.add_argument(
diff --git a/pre_commit_hooks/forbid_new_submodules.py b/pre_commit_hooks/forbid_new_submodules.py index c9464cf..bdbd6f7 100644 --- a/pre_commit_hooks/forbid_new_submodules.py +++ b/pre_commit_hooks/forbid_new_submodules.py
@@ -2,10 +2,13 @@ from __future__ import print_function from __future__ import unicode_literals +from typing import Optional +from typing import Sequence + from pre_commit_hooks.util import cmd_output -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int # `argv` is ignored, pre-commit will send us a list of files that we # don't care about added_diff = cmd_output(
diff --git a/pre_commit_hooks/mixed_line_ending.py b/pre_commit_hooks/mixed_line_ending.py index e35a65c..90aef03 100644 --- a/pre_commit_hooks/mixed_line_ending.py +++ b/pre_commit_hooks/mixed_line_ending.py
@@ -4,6 +4,9 @@ import argparse import collections +from typing import Dict +from typing import Optional +from typing import Sequence CRLF = b'\r\n' @@ -14,7 +17,7 @@ FIX_TO_LINE_ENDING = {'cr': CR, 'crlf': CRLF, 'lf': LF} -def _fix(filename, contents, ending): +def _fix(filename, contents, ending): # type: (str, bytes, bytes) -> None new_contents = b''.join( line.rstrip(b'\r\n') + ending for line in contents.splitlines(True) ) @@ -22,11 +25,11 @@ f.write(new_contents) -def fix_filename(filename, fix): +def fix_filename(filename, fix): # type: (str, str) -> int with open(filename, 'rb') as f: contents = f.read() - counts = collections.defaultdict(int) + counts = collections.defaultdict(int) # type: Dict[bytes, int] for line in contents.splitlines(True): for ending in ALL_ENDINGS: @@ -63,7 +66,7 @@ return other_endings -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument( '-f', '--fix',
diff --git a/pre_commit_hooks/no_commit_to_branch.py b/pre_commit_hooks/no_commit_to_branch.py index fdd146b..6b68c91 100644 --- a/pre_commit_hooks/no_commit_to_branch.py +++ b/pre_commit_hooks/no_commit_to_branch.py
@@ -1,12 +1,15 @@ from __future__ import print_function import argparse +from typing import Optional +from typing import Sequence +from typing import Set from pre_commit_hooks.util import CalledProcessError from pre_commit_hooks.util import cmd_output -def is_on_branch(protected): +def is_on_branch(protected): # type: (Set[str]) -> bool try: branch = cmd_output('git', 'symbolic-ref', 'HEAD') except CalledProcessError: @@ -15,7 +18,7 @@ return '/'.join(chunks[2:]) in protected -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument( '-b', '--branch', action='append',
diff --git a/pre_commit_hooks/pretty_format_json.py b/pre_commit_hooks/pretty_format_json.py index 363037e..de7f8d7 100644 --- a/pre_commit_hooks/pretty_format_json.py +++ b/pre_commit_hooks/pretty_format_json.py
@@ -5,12 +5,20 @@ import json import sys from collections import OrderedDict +from typing import List +from typing import Mapping +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Union from six import text_type -def _get_pretty_format(contents, indent, ensure_ascii=True, sort_keys=True, top_keys=[]): +def _get_pretty_format(contents, indent, ensure_ascii=True, sort_keys=True, top_keys=()): + # type: (str, Union[int, str], bool, bool, Sequence[str]) -> str def pairs_first(pairs): + # type: (Sequence[Tuple[str, str]]) -> Mapping[str, str] before = [pair for pair in pairs if pair[0] in top_keys] before = sorted(before, key=lambda x: top_keys.index(x[0])) after = [pair for pair in pairs if pair[0] not in top_keys] @@ -27,13 +35,13 @@ return text_type(json_pretty) + '\n' -def _autofix(filename, new_contents): +def _autofix(filename, new_contents): # type: (str, str) -> None print('Fixing file {}'.format(filename)) with io.open(filename, 'w', encoding='UTF-8') as f: f.write(new_contents) -def parse_num_to_int(s): +def parse_num_to_int(s): # type: (str) -> Union[int, str] """Convert string numbers to int, leaving strings as is.""" try: return int(s) @@ -41,11 +49,11 @@ return s -def parse_topkeys(s): +def parse_topkeys(s): # type: (str) -> List[str] return s.split(',') -def pretty_format_json(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument( '--autofix', @@ -117,4 +125,4 @@ if __name__ == '__main__': - sys.exit(pretty_format_json()) + sys.exit(main())
diff --git a/pre_commit_hooks/requirements_txt_fixer.py b/pre_commit_hooks/requirements_txt_fixer.py index 6dcf8d0..3f85a17 100644 --- a/pre_commit_hooks/requirements_txt_fixer.py +++ b/pre_commit_hooks/requirements_txt_fixer.py
@@ -1,6 +1,10 @@ from __future__ import print_function import argparse +from typing import IO +from typing import List +from typing import Optional +from typing import Sequence PASS = 0 @@ -9,21 +13,23 @@ class Requirement(object): - def __init__(self): + def __init__(self): # type: () -> None super(Requirement, self).__init__() - self.value = None - self.comments = [] + self.value = None # type: Optional[bytes] + self.comments = [] # type: List[bytes] @property - def name(self): + def name(self): # type: () -> bytes + assert self.value is not None, self.value if self.value.startswith(b'-e '): return self.value.lower().partition(b'=')[-1] return self.value.lower().partition(b'==')[0] - def __lt__(self, requirement): + def __lt__(self, requirement): # type: (Requirement) -> bool # \n means top of file comment, so always return True, # otherwise just do a string comparison with value. + assert self.value is not None, self.value if self.value == b'\n': return True elif requirement.value == b'\n': @@ -32,10 +38,10 @@ return self.name < requirement.name -def fix_requirements(f): - requirements = [] +def fix_requirements(f): # type: (IO[bytes]) -> int + requirements = [] # type: List[Requirement] before = tuple(f) - after = [] + after = [] # type: List[bytes] before_string = b''.join(before) @@ -46,6 +52,7 @@ for line in before: # If the most recent requirement object has a value, then it's # time to start building the next requirement object. 
+ if not len(requirements) or requirements[-1].value is not None: requirements.append(Requirement()) @@ -78,6 +85,7 @@ for requirement in sorted(requirements): after.extend(requirement.comments) + assert requirement.value is not None, requirement.value after.append(requirement.value) after.extend(rest) @@ -92,7 +100,7 @@ return FAIL -def fix_requirements_txt(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*', help='Filenames to fix') args = parser.parse_args(argv) @@ -109,3 +117,7 @@ retv |= ret_for_file return retv + + +if __name__ == '__main__': + exit(main())
diff --git a/pre_commit_hooks/sort_simple_yaml.py b/pre_commit_hooks/sort_simple_yaml.py index 7afae91..3c8ef16 100755 --- a/pre_commit_hooks/sort_simple_yaml.py +++ b/pre_commit_hooks/sort_simple_yaml.py
@@ -21,12 +21,15 @@ from __future__ import print_function import argparse +from typing import List +from typing import Optional +from typing import Sequence QUOTES = ["'", '"'] -def sort(lines): +def sort(lines): # type: (List[str]) -> List[str] """Sort a YAML file in alphabetical order, keeping blocks together. :param lines: array of strings (without newlines) @@ -44,7 +47,7 @@ return new_lines -def parse_block(lines, header=False): +def parse_block(lines, header=False): # type: (List[str], bool) -> List[str] """Parse and return a single block, popping off the start of `lines`. If parsing a header block, we stop after we reach a line that is not a @@ -60,7 +63,7 @@ return block_lines -def parse_blocks(lines): +def parse_blocks(lines): # type: (List[str]) -> List[List[str]] """Parse and return all possible blocks, popping off the start of `lines`. :param lines: list of lines @@ -77,7 +80,7 @@ return blocks -def first_key(lines): +def first_key(lines): # type: (List[str]) -> str """Returns a string representing the sort key of a block. The sort key is the first YAML key we encounter, ignoring comments, and @@ -95,9 +98,11 @@ if any(line.startswith(quote) for quote in QUOTES): return line[1:] return line + else: + return '' # not actually reached in reality -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*', help='Filenames to fix') args = parser.parse_args(argv)
diff --git a/pre_commit_hooks/string_fixer.py b/pre_commit_hooks/string_fixer.py index c432682..a5ea1ea 100644 --- a/pre_commit_hooks/string_fixer.py +++ b/pre_commit_hooks/string_fixer.py
@@ -4,34 +4,39 @@ import argparse import io +import re import tokenize +from typing import List +from typing import Optional +from typing import Sequence + +START_QUOTE_RE = re.compile('^[a-zA-Z]*"') -double_quote_starts = tuple(s for s in tokenize.single_quoted if '"' in s) - - -def handle_match(token_text): +def handle_match(token_text): # type: (str) -> str if '"""' in token_text or "'''" in token_text: return token_text - for double_quote_start in double_quote_starts: - if token_text.startswith(double_quote_start): - meat = token_text[len(double_quote_start):-1] - if '"' in meat or "'" in meat: - break - return double_quote_start.replace('"', "'") + meat + "'" - return token_text + match = START_QUOTE_RE.match(token_text) + if match is not None: + meat = token_text[match.end():-1] + if '"' in meat or "'" in meat: + return token_text + else: + return match.group().replace('"', "'") + meat + "'" + else: + return token_text -def get_line_offsets_by_line_no(src): +def get_line_offsets_by_line_no(src): # type: (str) -> List[int] # Padded so we can index with line number - offsets = [None, 0] + offsets = [-1, 0] for line in src.splitlines(): offsets.append(offsets[-1] + len(line) + 1) return offsets -def fix_strings(filename): +def fix_strings(filename): # type: (str) -> int with io.open(filename, encoding='UTF-8') as f: contents = f.read() line_offsets = get_line_offsets_by_line_no(contents) @@ -60,7 +65,7 @@ return 0 -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*', help='Filenames to fix') args = parser.parse_args(argv) @@ -74,3 +79,7 @@ retv |= return_value return retv + + +if __name__ == '__main__': + exit(main())
diff --git a/pre_commit_hooks/tests_should_end_in_test.py b/pre_commit_hooks/tests_should_end_in_test.py index 9bea20d..7a1e7c0 100644 --- a/pre_commit_hooks/tests_should_end_in_test.py +++ b/pre_commit_hooks/tests_should_end_in_test.py
@@ -1,12 +1,14 @@ from __future__ import print_function import argparse +import os.path import re import sys -from os.path import basename +from typing import Optional +from typing import Sequence -def validate_files(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*') parser.add_argument( @@ -18,7 +20,7 @@ retcode = 0 test_name_pattern = 'test.*.py' if args.django else '.*_test.py' for filename in args.filenames: - base = basename(filename) + base = os.path.basename(filename) if ( not re.match(test_name_pattern, base) and not base == '__init__.py' and @@ -35,4 +37,4 @@ if __name__ == '__main__': - sys.exit(validate_files()) + sys.exit(main())
diff --git a/pre_commit_hooks/trailing_whitespace_fixer.py b/pre_commit_hooks/trailing_whitespace_fixer.py index 1b54fbd..4fe7975 100644 --- a/pre_commit_hooks/trailing_whitespace_fixer.py +++ b/pre_commit_hooks/trailing_whitespace_fixer.py
@@ -3,9 +3,11 @@ import argparse import os import sys +from typing import Optional +from typing import Sequence -def _fix_file(filename, is_markdown): +def _fix_file(filename, is_markdown): # type: (str, bool) -> bool with open(filename, mode='rb') as file_processed: lines = file_processed.readlines() newlines = [_process_line(line, is_markdown) for line in lines] @@ -18,7 +20,7 @@ return False -def _process_line(line, is_markdown): +def _process_line(line, is_markdown): # type: (bytes, bool) -> bytes if line[-2:] == b'\r\n': eol = b'\r\n' elif line[-1:] == b'\n': @@ -31,7 +33,7 @@ return line.rstrip() + eol -def main(argv=None): +def main(argv=None): # type: (Optional[Sequence[str]]) -> int parser = argparse.ArgumentParser() parser.add_argument( '--no-markdown-linebreak-ext',
diff --git a/pre_commit_hooks/util.py b/pre_commit_hooks/util.py index 269b553..5d1d11b 100644 --- a/pre_commit_hooks/util.py +++ b/pre_commit_hooks/util.py
@@ -3,23 +3,25 @@ from __future__ import unicode_literals import subprocess +from typing import Any +from typing import Set class CalledProcessError(RuntimeError): pass -def added_files(): +def added_files(): # type: () -> Set[str] return set(cmd_output( 'git', 'diff', '--staged', '--name-only', '--diff-filter=A', ).splitlines()) -def cmd_output(*cmd, **kwargs): +def cmd_output(*cmd, **kwargs): # type: (*str, **Any) -> str retcode = kwargs.pop('retcode', 0) - popen_kwargs = {'stdout': subprocess.PIPE, 'stderr': subprocess.PIPE} - popen_kwargs.update(kwargs) - proc = subprocess.Popen(cmd, **popen_kwargs) + kwargs.setdefault('stdout', subprocess.PIPE) + kwargs.setdefault('stderr', subprocess.PIPE) + proc = subprocess.Popen(cmd, **kwargs) stdout, stderr = proc.communicate() stdout = stdout.decode('UTF-8') if stderr is not None: