Add type annotations (PEP 484 type comments) across pre-commit-hooks and rename hook entry points to main()
diff --git a/pre_commit_hooks/autopep8_wrapper.py b/pre_commit_hooks/autopep8_wrapper.py
index 9951924..8b69a04 100644
--- a/pre_commit_hooks/autopep8_wrapper.py
+++ b/pre_commit_hooks/autopep8_wrapper.py
@@ -3,7 +3,7 @@
from __future__ import unicode_literals
-def main(argv=None):
+def main(): # type: () -> int
raise SystemExit(
'autopep8-wrapper is deprecated. Instead use autopep8 directly via '
'https://github.com/pre-commit/mirrors-autopep8',
diff --git a/pre_commit_hooks/check_added_large_files.py b/pre_commit_hooks/check_added_large_files.py
index 2d06706..be39498 100644
--- a/pre_commit_hooks/check_added_large_files.py
+++ b/pre_commit_hooks/check_added_large_files.py
@@ -7,13 +7,17 @@
import json
import math
import os
+from typing import Iterable
+from typing import Optional
+from typing import Sequence
+from typing import Set
from pre_commit_hooks.util import added_files
from pre_commit_hooks.util import CalledProcessError
from pre_commit_hooks.util import cmd_output
-def lfs_files():
+def lfs_files(): # type: () -> Set[str]
try:
# Introduced in git-lfs 2.2.0, first working in 2.2.1
lfs_ret = cmd_output('git', 'lfs', 'status', '--json')
@@ -24,6 +28,7 @@
def find_large_added_files(filenames, maxkb):
+ # type: (Iterable[str], int) -> int
# Find all added files that are also in the list of files pre-commit tells
# us about
filenames = (added_files() & set(filenames)) - lfs_files()
@@ -38,7 +43,7 @@
return retv
-def main(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument(
'filenames', nargs='*',
diff --git a/pre_commit_hooks/check_ast.py b/pre_commit_hooks/check_ast.py
index ded65e4..0df3540 100644
--- a/pre_commit_hooks/check_ast.py
+++ b/pre_commit_hooks/check_ast.py
@@ -7,9 +7,11 @@
import platform
import sys
import traceback
+from typing import Optional
+from typing import Sequence
-def check_ast(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*')
args = parser.parse_args(argv)
@@ -34,4 +36,4 @@
if __name__ == '__main__':
- exit(check_ast())
+ exit(main())
diff --git a/pre_commit_hooks/check_builtin_literals.py b/pre_commit_hooks/check_builtin_literals.py
index 4a4b9ce..874c68c 100644
--- a/pre_commit_hooks/check_builtin_literals.py
+++ b/pre_commit_hooks/check_builtin_literals.py
@@ -4,6 +4,10 @@
import ast
import collections
import sys
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Set
BUILTIN_TYPES = {
@@ -22,14 +26,17 @@
class BuiltinTypeVisitor(ast.NodeVisitor):
def __init__(self, ignore=None, allow_dict_kwargs=True):
- self.builtin_type_calls = []
+ # type: (Optional[Sequence[str]], bool) -> None
+ self.builtin_type_calls = [] # type: List[BuiltinTypeCall]
self.ignore = set(ignore) if ignore else set()
self.allow_dict_kwargs = allow_dict_kwargs
- def _check_dict_call(self, node):
+ def _check_dict_call(self, node): # type: (ast.Call) -> bool
+
return self.allow_dict_kwargs and (getattr(node, 'kwargs', None) or getattr(node, 'keywords', None))
- def visit_Call(self, node):
+ def visit_Call(self, node): # type: (ast.Call) -> None
+
if not isinstance(node.func, ast.Name):
# Ignore functions that are object attributes (`foo.bar()`).
# Assume that if the user calls `builtins.list()`, they know what
@@ -47,6 +54,7 @@
def check_file_for_builtin_type_constructors(filename, ignore=None, allow_dict_kwargs=True):
+ # type: (str, Optional[Sequence[str]], bool) -> List[BuiltinTypeCall]
with open(filename, 'rb') as f:
tree = ast.parse(f.read(), filename=filename)
visitor = BuiltinTypeVisitor(ignore=ignore, allow_dict_kwargs=allow_dict_kwargs)
@@ -54,24 +62,22 @@
return visitor.builtin_type_calls
-def parse_args(argv):
- def parse_ignore(value):
- return set(value.split(','))
+def parse_ignore(value): # type: (str) -> Set[str]
+ return set(value.split(','))
+
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*')
parser.add_argument('--ignore', type=parse_ignore, default=set())
- allow_dict_kwargs = parser.add_mutually_exclusive_group(required=False)
- allow_dict_kwargs.add_argument('--allow-dict-kwargs', action='store_true')
- allow_dict_kwargs.add_argument('--no-allow-dict-kwargs', dest='allow_dict_kwargs', action='store_false')
- allow_dict_kwargs.set_defaults(allow_dict_kwargs=True)
+ mutex = parser.add_mutually_exclusive_group(required=False)
+ mutex.add_argument('--allow-dict-kwargs', action='store_true')
+ mutex.add_argument('--no-allow-dict-kwargs', dest='allow_dict_kwargs', action='store_false')
+ mutex.set_defaults(allow_dict_kwargs=True)
- return parser.parse_args(argv)
+ args = parser.parse_args(argv)
-
-def main(argv=None):
- args = parse_args(argv)
rc = 0
for filename in args.filenames:
calls = check_file_for_builtin_type_constructors(
diff --git a/pre_commit_hooks/check_byte_order_marker.py b/pre_commit_hooks/check_byte_order_marker.py
index 1541b30..10667c3 100644
--- a/pre_commit_hooks/check_byte_order_marker.py
+++ b/pre_commit_hooks/check_byte_order_marker.py
@@ -3,9 +3,11 @@
from __future__ import unicode_literals
import argparse
+from typing import Optional
+from typing import Sequence
-def main(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='Filenames to check')
args = parser.parse_args(argv)
diff --git a/pre_commit_hooks/check_case_conflict.py b/pre_commit_hooks/check_case_conflict.py
index 0f78296..e343d61 100644
--- a/pre_commit_hooks/check_case_conflict.py
+++ b/pre_commit_hooks/check_case_conflict.py
@@ -3,16 +3,20 @@
from __future__ import unicode_literals
import argparse
+from typing import Iterable
+from typing import Optional
+from typing import Sequence
+from typing import Set
from pre_commit_hooks.util import added_files
from pre_commit_hooks.util import cmd_output
-def lower_set(iterable):
+def lower_set(iterable): # type: (Iterable[str]) -> Set[str]
return {x.lower() for x in iterable}
-def find_conflicting_filenames(filenames):
+def find_conflicting_filenames(filenames): # type: (Sequence[str]) -> int
repo_files = set(cmd_output('git', 'ls-files').splitlines())
relevant_files = set(filenames) | added_files()
repo_files -= relevant_files
@@ -41,7 +45,7 @@
return retv
-def main(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument(
'filenames', nargs='*',
diff --git a/pre_commit_hooks/check_docstring_first.py b/pre_commit_hooks/check_docstring_first.py
index 9988378..f4639f1 100644
--- a/pre_commit_hooks/check_docstring_first.py
+++ b/pre_commit_hooks/check_docstring_first.py
@@ -5,6 +5,8 @@
import argparse
import io
import tokenize
+from typing import Optional
+from typing import Sequence
NON_CODE_TOKENS = frozenset((
@@ -13,6 +15,7 @@
def check_docstring_first(src, filename='<unknown>'):
+ # type: (str, str) -> int
"""Returns nonzero if the source has what looks like a docstring that is
not at the beginning of the source.
@@ -50,7 +53,7 @@
return 0
-def main(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*')
args = parser.parse_args(argv)
diff --git a/pre_commit_hooks/check_executables_have_shebangs.py b/pre_commit_hooks/check_executables_have_shebangs.py
index 89ac6e5..c936a5d 100644
--- a/pre_commit_hooks/check_executables_have_shebangs.py
+++ b/pre_commit_hooks/check_executables_have_shebangs.py
@@ -6,9 +6,11 @@
import argparse
import pipes
import sys
+from typing import Optional
+from typing import Sequence
-def check_has_shebang(path):
+def check_has_shebang(path): # type: (str) -> int
with open(path, 'rb') as f:
first_bytes = f.read(2)
@@ -27,7 +29,7 @@
return 0
-def main(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('filenames', nargs='*')
args = parser.parse_args(argv)
@@ -38,3 +40,7 @@
retv |= check_has_shebang(filename)
return retv
+
+
+if __name__ == '__main__':
+ exit(main())
diff --git a/pre_commit_hooks/check_json.py b/pre_commit_hooks/check_json.py
index b403f4b..b939350 100644
--- a/pre_commit_hooks/check_json.py
+++ b/pre_commit_hooks/check_json.py
@@ -4,9 +4,11 @@
import io
import json
import sys
+from typing import Optional
+from typing import Sequence
-def check_json(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='JSON filenames to check.')
args = parser.parse_args(argv)
@@ -22,4 +24,4 @@
if __name__ == '__main__':
- sys.exit(check_json())
+ sys.exit(main())
diff --git a/pre_commit_hooks/check_merge_conflict.py b/pre_commit_hooks/check_merge_conflict.py
index 6db5efe..74e4ae1 100644
--- a/pre_commit_hooks/check_merge_conflict.py
+++ b/pre_commit_hooks/check_merge_conflict.py
@@ -2,6 +2,9 @@
import argparse
import os.path
+from typing import Optional
+from typing import Sequence
+
CONFLICT_PATTERNS = [
b'<<<<<<< ',
@@ -12,7 +15,7 @@
WARNING_MSG = 'Merge conflict string "{0}" found in {1}:{2}'
-def is_in_merge():
+def is_in_merge(): # type: () -> bool
return (
os.path.exists(os.path.join('.git', 'MERGE_MSG')) and
(
@@ -23,7 +26,7 @@
)
-def detect_merge_conflict(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*')
parser.add_argument('--assume-in-merge', action='store_true')
@@ -47,4 +50,4 @@
if __name__ == '__main__':
- exit(detect_merge_conflict())
+ exit(main())
diff --git a/pre_commit_hooks/check_symlinks.py b/pre_commit_hooks/check_symlinks.py
index 010c871..736bf99 100644
--- a/pre_commit_hooks/check_symlinks.py
+++ b/pre_commit_hooks/check_symlinks.py
@@ -4,9 +4,11 @@
import argparse
import os.path
+from typing import Optional
+from typing import Sequence
-def check_symlinks(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser(description='Checks for broken symlinks.')
parser.add_argument('filenames', nargs='*', help='Filenames to check')
args = parser.parse_args(argv)
@@ -25,4 +27,4 @@
if __name__ == '__main__':
- exit(check_symlinks())
+ exit(main())
diff --git a/pre_commit_hooks/check_vcs_permalinks.py b/pre_commit_hooks/check_vcs_permalinks.py
index f0dcf5b..f6e2a7d 100644
--- a/pre_commit_hooks/check_vcs_permalinks.py
+++ b/pre_commit_hooks/check_vcs_permalinks.py
@@ -5,6 +5,8 @@
import argparse
import re
import sys
+from typing import Optional
+from typing import Sequence
GITHUB_NON_PERMALINK = re.compile(
@@ -12,7 +14,7 @@
)
-def _check_filename(filename):
+def _check_filename(filename): # type: (str) -> int
retv = 0
with open(filename, 'rb') as f:
for i, line in enumerate(f, 1):
@@ -24,7 +26,7 @@
return retv
-def main(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*')
args = parser.parse_args(argv)
diff --git a/pre_commit_hooks/check_xml.py b/pre_commit_hooks/check_xml.py
index a4c11a5..66e10ba 100644
--- a/pre_commit_hooks/check_xml.py
+++ b/pre_commit_hooks/check_xml.py
@@ -5,10 +5,12 @@
import argparse
import io
import sys
-import xml.sax
+import xml.sax.handler
+from typing import Optional
+from typing import Sequence
-def check_xml(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='XML filenames to check.')
args = parser.parse_args(argv)
@@ -17,7 +19,7 @@
for filename in args.filenames:
try:
with io.open(filename, 'rb') as xml_file:
- xml.sax.parse(xml_file, xml.sax.ContentHandler())
+ xml.sax.parse(xml_file, xml.sax.handler.ContentHandler())
except xml.sax.SAXException as exc:
print('{}: Failed to xml parse ({})'.format(filename, exc))
retval = 1
@@ -25,4 +27,4 @@
if __name__ == '__main__':
- sys.exit(check_xml())
+ sys.exit(main())
diff --git a/pre_commit_hooks/check_yaml.py b/pre_commit_hooks/check_yaml.py
index 208737f..b638684 100644
--- a/pre_commit_hooks/check_yaml.py
+++ b/pre_commit_hooks/check_yaml.py
@@ -3,22 +3,26 @@
import argparse
import collections
import sys
+from typing import Any
+from typing import Generator
+from typing import Optional
+from typing import Sequence
import ruamel.yaml
yaml = ruamel.yaml.YAML(typ='safe')
-def _exhaust(gen):
+def _exhaust(gen): # type: (Generator[Any, None, None]) -> None
for _ in gen:
pass
-def _parse_unsafe(*args, **kwargs):
+def _parse_unsafe(*args, **kwargs): # type: (*Any, **Any) -> None
_exhaust(yaml.parse(*args, **kwargs))
-def _load_all(*args, **kwargs):
+def _load_all(*args, **kwargs): # type: (*Any, **Any) -> None
_exhaust(yaml.load_all(*args, **kwargs))
@@ -31,7 +35,7 @@
}
-def check_yaml(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument(
'-m', '--multi', '--allow-multiple-documents', action='store_true',
@@ -63,4 +67,4 @@
if __name__ == '__main__':
- sys.exit(check_yaml())
+ sys.exit(main())
diff --git a/pre_commit_hooks/debug_statement_hook.py b/pre_commit_hooks/debug_statement_hook.py
index 5d32277..02dd3b2 100644
--- a/pre_commit_hooks/debug_statement_hook.py
+++ b/pre_commit_hooks/debug_statement_hook.py
@@ -5,6 +5,9 @@
import ast
import collections
import traceback
+from typing import List
+from typing import Optional
+from typing import Sequence
DEBUG_STATEMENTS = {'pdb', 'ipdb', 'pudb', 'q', 'rdb'}
@@ -12,21 +15,21 @@
class DebugStatementParser(ast.NodeVisitor):
- def __init__(self):
- self.breakpoints = []
+ def __init__(self): # type: () -> None
+ self.breakpoints = [] # type: List[Debug]
- def visit_Import(self, node):
+ def visit_Import(self, node): # type: (ast.Import) -> None
for name in node.names:
if name.name in DEBUG_STATEMENTS:
st = Debug(node.lineno, node.col_offset, name.name, 'imported')
self.breakpoints.append(st)
- def visit_ImportFrom(self, node):
+ def visit_ImportFrom(self, node): # type: (ast.ImportFrom) -> None
if node.module in DEBUG_STATEMENTS:
st = Debug(node.lineno, node.col_offset, node.module, 'imported')
self.breakpoints.append(st)
- def visit_Call(self, node):
+ def visit_Call(self, node): # type: (ast.Call) -> None
"""python3.7+ breakpoint()"""
if isinstance(node.func, ast.Name) and node.func.id == 'breakpoint':
st = Debug(node.lineno, node.col_offset, node.func.id, 'called')
@@ -34,7 +37,7 @@
self.generic_visit(node)
-def check_file(filename):
+def check_file(filename): # type: (str) -> int
try:
with open(filename, 'rb') as f:
ast_obj = ast.parse(f.read(), filename=filename)
@@ -58,7 +61,7 @@
return int(bool(visitor.breakpoints))
-def main(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='Filenames to run')
args = parser.parse_args(argv)
diff --git a/pre_commit_hooks/detect_aws_credentials.py b/pre_commit_hooks/detect_aws_credentials.py
index ecd9d40..3c87d11 100644
--- a/pre_commit_hooks/detect_aws_credentials.py
+++ b/pre_commit_hooks/detect_aws_credentials.py
@@ -3,11 +3,16 @@
import argparse
import os
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Set
from six.moves import configparser
-def get_aws_credential_files_from_env():
+def get_aws_credential_files_from_env(): # type: () -> Set[str]
"""Extract credential file paths from environment variables."""
files = set()
for env_var in (
@@ -19,7 +24,7 @@
return files
-def get_aws_secrets_from_env():
+def get_aws_secrets_from_env(): # type: () -> Set[str]
"""Extract AWS secrets from environment variables."""
keys = set()
for env_var in (
@@ -30,7 +35,7 @@
return keys
-def get_aws_secrets_from_file(credentials_file):
+def get_aws_secrets_from_file(credentials_file): # type: (str) -> Set[str]
"""Extract AWS secrets from configuration files.
Read an ini-style configuration file and return a set with all found AWS
@@ -62,6 +67,7 @@
def check_file_for_aws_keys(filenames, keys):
+ # type: (Sequence[str], Set[str]) -> List[Dict[str, str]]
"""Check if files contain AWS secrets.
Return a list of all files containing AWS secrets and keys found, with all
@@ -82,7 +88,7 @@
return bad_files
-def main(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='+', help='Filenames to run')
parser.add_argument(
@@ -111,7 +117,7 @@
# of files to to gather AWS secrets from.
credential_files |= get_aws_credential_files_from_env()
- keys = set()
+ keys = set() # type: Set[str]
for credential_file in credential_files:
keys |= get_aws_secrets_from_file(credential_file)
diff --git a/pre_commit_hooks/detect_private_key.py b/pre_commit_hooks/detect_private_key.py
index c8ee961..d31957d 100644
--- a/pre_commit_hooks/detect_private_key.py
+++ b/pre_commit_hooks/detect_private_key.py
@@ -2,6 +2,8 @@
import argparse
import sys
+from typing import Optional
+from typing import Sequence
BLACKLIST = [
b'BEGIN RSA PRIVATE KEY',
@@ -15,7 +17,7 @@
]
-def detect_private_key(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='Filenames to check')
args = parser.parse_args(argv)
@@ -37,4 +39,4 @@
if __name__ == '__main__':
- sys.exit(detect_private_key())
+ sys.exit(main())
diff --git a/pre_commit_hooks/end_of_file_fixer.py b/pre_commit_hooks/end_of_file_fixer.py
index 5ab1b7b..4e77c94 100644
--- a/pre_commit_hooks/end_of_file_fixer.py
+++ b/pre_commit_hooks/end_of_file_fixer.py
@@ -4,9 +4,12 @@
import argparse
import os
import sys
+from typing import IO
+from typing import Optional
+from typing import Sequence
-def fix_file(file_obj):
+def fix_file(file_obj): # type: (IO[bytes]) -> int
# Test for newline at end of file
# Empty files will throw IOError here
try:
@@ -49,7 +52,7 @@
return 0
-def end_of_file_fixer(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='Filenames to fix')
args = parser.parse_args(argv)
@@ -68,4 +71,4 @@
if __name__ == '__main__':
- sys.exit(end_of_file_fixer())
+ sys.exit(main())
diff --git a/pre_commit_hooks/file_contents_sorter.py b/pre_commit_hooks/file_contents_sorter.py
index fe7f7ee..6f13c98 100644
--- a/pre_commit_hooks/file_contents_sorter.py
+++ b/pre_commit_hooks/file_contents_sorter.py
@@ -12,12 +12,15 @@
from __future__ import print_function
import argparse
+from typing import IO
+from typing import Optional
+from typing import Sequence
PASS = 0
FAIL = 1
-def sort_file_contents(f):
+def sort_file_contents(f): # type: (IO[bytes]) -> int
before = list(f)
after = sorted([line.strip(b'\n\r') for line in before if line.strip()])
@@ -33,7 +36,7 @@
return FAIL
-def main(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='+', help='Files to sort')
args = parser.parse_args(argv)
diff --git a/pre_commit_hooks/fix_encoding_pragma.py b/pre_commit_hooks/fix_encoding_pragma.py
index 3bf234e..b0b5c8e 100644
--- a/pre_commit_hooks/fix_encoding_pragma.py
+++ b/pre_commit_hooks/fix_encoding_pragma.py
@@ -4,11 +4,15 @@
import argparse
import collections
+from typing import IO
+from typing import Optional
+from typing import Sequence
+from typing import Union
DEFAULT_PRAGMA = b'# -*- coding: utf-8 -*-\n'
-def has_coding(line):
+def has_coding(line): # type: (bytes) -> bool
if not line.strip():
return False
return (
@@ -33,15 +37,16 @@
__slots__ = ()
@property
- def has_any_pragma(self):
+ def has_any_pragma(self): # type: () -> bool
return self.pragma_status is not False
- def is_expected_pragma(self, remove):
+ def is_expected_pragma(self, remove): # type: (bool) -> bool
expected_pragma_status = not remove
return self.pragma_status is expected_pragma_status
def _get_expected_contents(first_line, second_line, rest, expected_pragma):
+ # type: (bytes, bytes, bytes, bytes) -> ExpectedContents
if first_line.startswith(b'#!'):
shebang = first_line
potential_coding = second_line
@@ -51,7 +56,7 @@
rest = second_line + rest
if potential_coding == expected_pragma:
- pragma_status = True
+ pragma_status = True # type: Optional[bool]
elif has_coding(potential_coding):
pragma_status = None
else:
@@ -64,6 +69,7 @@
def fix_encoding_pragma(f, remove=False, expected_pragma=DEFAULT_PRAGMA):
+ # type: (IO[bytes], bool, bytes) -> int
expected = _get_expected_contents(
f.readline(), f.readline(), f.read(), expected_pragma,
)
@@ -93,17 +99,17 @@
return 1
-def _normalize_pragma(pragma):
+def _normalize_pragma(pragma): # type: (Union[bytes, str]) -> bytes
if not isinstance(pragma, bytes):
pragma = pragma.encode('UTF-8')
return pragma.rstrip() + b'\n'
-def _to_disp(pragma):
+def _to_disp(pragma): # type: (bytes) -> str
return pragma.decode().rstrip()
-def main(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser('Fixes the encoding pragma of python files')
parser.add_argument('filenames', nargs='*', help='Filenames to fix')
parser.add_argument(
diff --git a/pre_commit_hooks/forbid_new_submodules.py b/pre_commit_hooks/forbid_new_submodules.py
index c9464cf..bdbd6f7 100644
--- a/pre_commit_hooks/forbid_new_submodules.py
+++ b/pre_commit_hooks/forbid_new_submodules.py
@@ -2,10 +2,13 @@
from __future__ import print_function
from __future__ import unicode_literals
+from typing import Optional
+from typing import Sequence
+
from pre_commit_hooks.util import cmd_output
-def main(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
# `argv` is ignored, pre-commit will send us a list of files that we
# don't care about
added_diff = cmd_output(
diff --git a/pre_commit_hooks/mixed_line_ending.py b/pre_commit_hooks/mixed_line_ending.py
index e35a65c..90aef03 100644
--- a/pre_commit_hooks/mixed_line_ending.py
+++ b/pre_commit_hooks/mixed_line_ending.py
@@ -4,6 +4,9 @@
import argparse
import collections
+from typing import Dict
+from typing import Optional
+from typing import Sequence
CRLF = b'\r\n'
@@ -14,7 +17,7 @@
FIX_TO_LINE_ENDING = {'cr': CR, 'crlf': CRLF, 'lf': LF}
-def _fix(filename, contents, ending):
+def _fix(filename, contents, ending): # type: (str, bytes, bytes) -> None
new_contents = b''.join(
line.rstrip(b'\r\n') + ending for line in contents.splitlines(True)
)
@@ -22,11 +25,11 @@
f.write(new_contents)
-def fix_filename(filename, fix):
+def fix_filename(filename, fix): # type: (str, str) -> int
with open(filename, 'rb') as f:
contents = f.read()
- counts = collections.defaultdict(int)
+ counts = collections.defaultdict(int) # type: Dict[bytes, int]
for line in contents.splitlines(True):
for ending in ALL_ENDINGS:
@@ -63,7 +66,7 @@
return other_endings
-def main(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument(
'-f', '--fix',
diff --git a/pre_commit_hooks/no_commit_to_branch.py b/pre_commit_hooks/no_commit_to_branch.py
index fdd146b..6b68c91 100644
--- a/pre_commit_hooks/no_commit_to_branch.py
+++ b/pre_commit_hooks/no_commit_to_branch.py
@@ -1,12 +1,15 @@
from __future__ import print_function
import argparse
+from typing import Optional
+from typing import Sequence
+from typing import Set
from pre_commit_hooks.util import CalledProcessError
from pre_commit_hooks.util import cmd_output
-def is_on_branch(protected):
+def is_on_branch(protected): # type: (Set[str]) -> bool
try:
branch = cmd_output('git', 'symbolic-ref', 'HEAD')
except CalledProcessError:
@@ -15,7 +18,7 @@
return '/'.join(chunks[2:]) in protected
-def main(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument(
'-b', '--branch', action='append',
diff --git a/pre_commit_hooks/pretty_format_json.py b/pre_commit_hooks/pretty_format_json.py
index 363037e..de7f8d7 100644
--- a/pre_commit_hooks/pretty_format_json.py
+++ b/pre_commit_hooks/pretty_format_json.py
@@ -5,12 +5,20 @@
import json
import sys
from collections import OrderedDict
+from typing import List
+from typing import Mapping
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
from six import text_type
-def _get_pretty_format(contents, indent, ensure_ascii=True, sort_keys=True, top_keys=[]):
+def _get_pretty_format(contents, indent, ensure_ascii=True, sort_keys=True, top_keys=()):
+ # type: (str, str, bool, bool, Sequence[str]) -> str
def pairs_first(pairs):
+ # type: (Sequence[Tuple[str, str]]) -> Mapping[str, str]
before = [pair for pair in pairs if pair[0] in top_keys]
before = sorted(before, key=lambda x: top_keys.index(x[0]))
after = [pair for pair in pairs if pair[0] not in top_keys]
@@ -27,13 +35,13 @@
return text_type(json_pretty) + '\n'
-def _autofix(filename, new_contents):
+def _autofix(filename, new_contents): # type: (str, str) -> None
print('Fixing file {}'.format(filename))
with io.open(filename, 'w', encoding='UTF-8') as f:
f.write(new_contents)
-def parse_num_to_int(s):
+def parse_num_to_int(s): # type: (str) -> Union[int, str]
"""Convert string numbers to int, leaving strings as is."""
try:
return int(s)
@@ -41,11 +49,11 @@
return s
-def parse_topkeys(s):
+def parse_topkeys(s): # type: (str) -> List[str]
return s.split(',')
-def pretty_format_json(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument(
'--autofix',
@@ -117,4 +125,4 @@
if __name__ == '__main__':
- sys.exit(pretty_format_json())
+ sys.exit(main())
diff --git a/pre_commit_hooks/requirements_txt_fixer.py b/pre_commit_hooks/requirements_txt_fixer.py
index 6dcf8d0..3f85a17 100644
--- a/pre_commit_hooks/requirements_txt_fixer.py
+++ b/pre_commit_hooks/requirements_txt_fixer.py
@@ -1,6 +1,10 @@
from __future__ import print_function
import argparse
+from typing import IO
+from typing import List
+from typing import Optional
+from typing import Sequence
PASS = 0
@@ -9,21 +13,23 @@
class Requirement(object):
- def __init__(self):
+ def __init__(self): # type: () -> None
super(Requirement, self).__init__()
- self.value = None
- self.comments = []
+ self.value = None # type: Optional[bytes]
+ self.comments = [] # type: List[bytes]
@property
- def name(self):
+ def name(self): # type: () -> bytes
+ assert self.value is not None, self.value
if self.value.startswith(b'-e '):
return self.value.lower().partition(b'=')[-1]
return self.value.lower().partition(b'==')[0]
- def __lt__(self, requirement):
+def __lt__(self, requirement): # type: (Requirement) -> bool
# \n means top of file comment, so always return True,
# otherwise just do a string comparison with value.
+ assert self.value is not None, self.value
if self.value == b'\n':
return True
elif requirement.value == b'\n':
@@ -32,10 +38,10 @@
return self.name < requirement.name
-def fix_requirements(f):
- requirements = []
+def fix_requirements(f): # type: (IO[bytes]) -> int
+ requirements = [] # type: List[Requirement]
before = tuple(f)
- after = []
+ after = [] # type: List[bytes]
before_string = b''.join(before)
@@ -46,6 +52,7 @@
for line in before:
# If the most recent requirement object has a value, then it's
# time to start building the next requirement object.
+
if not len(requirements) or requirements[-1].value is not None:
requirements.append(Requirement())
@@ -78,6 +85,7 @@
for requirement in sorted(requirements):
after.extend(requirement.comments)
+ assert requirement.value is not None, requirement.value
after.append(requirement.value)
after.extend(rest)
@@ -92,7 +100,7 @@
return FAIL
-def fix_requirements_txt(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='Filenames to fix')
args = parser.parse_args(argv)
@@ -109,3 +117,7 @@
retv |= ret_for_file
return retv
+
+
+if __name__ == '__main__':
+ exit(main())
diff --git a/pre_commit_hooks/sort_simple_yaml.py b/pre_commit_hooks/sort_simple_yaml.py
index 7afae91..3c8ef16 100755
--- a/pre_commit_hooks/sort_simple_yaml.py
+++ b/pre_commit_hooks/sort_simple_yaml.py
@@ -21,12 +21,15 @@
from __future__ import print_function
import argparse
+from typing import List
+from typing import Optional
+from typing import Sequence
QUOTES = ["'", '"']
-def sort(lines):
+def sort(lines): # type: (List[str]) -> List[str]
"""Sort a YAML file in alphabetical order, keeping blocks together.
:param lines: array of strings (without newlines)
@@ -44,7 +47,7 @@
return new_lines
-def parse_block(lines, header=False):
+def parse_block(lines, header=False): # type: (List[str], bool) -> List[str]
"""Parse and return a single block, popping off the start of `lines`.
If parsing a header block, we stop after we reach a line that is not a
@@ -60,7 +63,7 @@
return block_lines
-def parse_blocks(lines):
+def parse_blocks(lines): # type: (List[str]) -> List[List[str]]
"""Parse and return all possible blocks, popping off the start of `lines`.
:param lines: list of lines
@@ -77,7 +80,7 @@
return blocks
-def first_key(lines):
+def first_key(lines): # type: (List[str]) -> str
"""Returns a string representing the sort key of a block.
The sort key is the first YAML key we encounter, ignoring comments, and
@@ -95,9 +98,11 @@
if any(line.startswith(quote) for quote in QUOTES):
return line[1:]
return line
+ else:
+ return '' # not actually reached in reality
-def main(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='Filenames to fix')
args = parser.parse_args(argv)
diff --git a/pre_commit_hooks/string_fixer.py b/pre_commit_hooks/string_fixer.py
index c432682..a5ea1ea 100644
--- a/pre_commit_hooks/string_fixer.py
+++ b/pre_commit_hooks/string_fixer.py
@@ -4,34 +4,39 @@
import argparse
import io
+import re
import tokenize
+from typing import List
+from typing import Optional
+from typing import Sequence
+
+START_QUOTE_RE = re.compile('^[a-zA-Z]*"')
-double_quote_starts = tuple(s for s in tokenize.single_quoted if '"' in s)
-
-
-def handle_match(token_text):
+def handle_match(token_text): # type: (str) -> str
if '"""' in token_text or "'''" in token_text:
return token_text
- for double_quote_start in double_quote_starts:
- if token_text.startswith(double_quote_start):
- meat = token_text[len(double_quote_start):-1]
- if '"' in meat or "'" in meat:
- break
- return double_quote_start.replace('"', "'") + meat + "'"
- return token_text
+ match = START_QUOTE_RE.match(token_text)
+ if match is not None:
+ meat = token_text[match.end():-1]
+ if '"' in meat or "'" in meat:
+ return token_text
+ else:
+ return match.group().replace('"', "'") + meat + "'"
+ else:
+ return token_text
-def get_line_offsets_by_line_no(src):
+def get_line_offsets_by_line_no(src): # type: (str) -> List[int]
# Padded so we can index with line number
- offsets = [None, 0]
+ offsets = [-1, 0]
for line in src.splitlines():
offsets.append(offsets[-1] + len(line) + 1)
return offsets
-def fix_strings(filename):
+def fix_strings(filename): # type: (str) -> int
with io.open(filename, encoding='UTF-8') as f:
contents = f.read()
line_offsets = get_line_offsets_by_line_no(contents)
@@ -60,7 +65,7 @@
return 0
-def main(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='Filenames to fix')
args = parser.parse_args(argv)
@@ -74,3 +79,7 @@
retv |= return_value
return retv
+
+
+if __name__ == '__main__':
+ exit(main())
diff --git a/pre_commit_hooks/tests_should_end_in_test.py b/pre_commit_hooks/tests_should_end_in_test.py
index 9bea20d..7a1e7c0 100644
--- a/pre_commit_hooks/tests_should_end_in_test.py
+++ b/pre_commit_hooks/tests_should_end_in_test.py
@@ -1,12 +1,14 @@
from __future__ import print_function
import argparse
+import os.path
import re
import sys
-from os.path import basename
+from typing import Optional
+from typing import Sequence
-def validate_files(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*')
parser.add_argument(
@@ -18,7 +20,7 @@
retcode = 0
test_name_pattern = 'test.*.py' if args.django else '.*_test.py'
for filename in args.filenames:
- base = basename(filename)
+ base = os.path.basename(filename)
if (
not re.match(test_name_pattern, base) and
not base == '__init__.py' and
@@ -35,4 +37,4 @@
if __name__ == '__main__':
- sys.exit(validate_files())
+ sys.exit(main())
diff --git a/pre_commit_hooks/trailing_whitespace_fixer.py b/pre_commit_hooks/trailing_whitespace_fixer.py
index 1b54fbd..4fe7975 100644
--- a/pre_commit_hooks/trailing_whitespace_fixer.py
+++ b/pre_commit_hooks/trailing_whitespace_fixer.py
@@ -3,9 +3,11 @@
import argparse
import os
import sys
+from typing import Optional
+from typing import Sequence
-def _fix_file(filename, is_markdown):
+def _fix_file(filename, is_markdown): # type: (str, bool) -> bool
with open(filename, mode='rb') as file_processed:
lines = file_processed.readlines()
newlines = [_process_line(line, is_markdown) for line in lines]
@@ -18,7 +20,7 @@
return False
-def _process_line(line, is_markdown):
+def _process_line(line, is_markdown): # type: (bytes, bool) -> bytes
if line[-2:] == b'\r\n':
eol = b'\r\n'
elif line[-1:] == b'\n':
@@ -31,7 +33,7 @@
return line.rstrip() + eol
-def main(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument(
'--no-markdown-linebreak-ext',
diff --git a/pre_commit_hooks/util.py b/pre_commit_hooks/util.py
index 269b553..5d1d11b 100644
--- a/pre_commit_hooks/util.py
+++ b/pre_commit_hooks/util.py
@@ -3,23 +3,25 @@
from __future__ import unicode_literals
import subprocess
+from typing import Any
+from typing import Set
class CalledProcessError(RuntimeError):
pass
-def added_files():
+def added_files(): # type: () -> Set[str]
return set(cmd_output(
'git', 'diff', '--staged', '--name-only', '--diff-filter=A',
).splitlines())
-def cmd_output(*cmd, **kwargs):
+def cmd_output(*cmd, **kwargs): # type: (*str, **Any) -> str
retcode = kwargs.pop('retcode', 0)
- popen_kwargs = {'stdout': subprocess.PIPE, 'stderr': subprocess.PIPE}
- popen_kwargs.update(kwargs)
- proc = subprocess.Popen(cmd, **popen_kwargs)
+ kwargs.setdefault('stdout', subprocess.PIPE)
+ kwargs.setdefault('stderr', subprocess.PIPE)
+ proc = subprocess.Popen(cmd, **kwargs)
stdout, stderr = proc.communicate()
stdout = stdout.decode('UTF-8')
if stderr is not None: