Apply typing to all of pre-commit-hooks
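
Annotate every hook and its supporting modules so the package type-checks
under mypy.  The hooks still support Python 2.7, so the annotations use
PEP 484 type comments rather than Python 3 annotation syntax; a minimal
sketch of the style applied throughout (hypothetical function, not taken
from this diff):

    from typing import Optional
    from typing import Sequence


    def main(argv=None):  # type: (Optional[Sequence[str]]) -> int
        ...
        return 0

Along the way several entry points are renamed to main() (with the
console_scripts and tests updated to match), a mypy.ini is added with
strict settings for pre_commit_hooks/ and relaxed ones for tests, and mypy
itself now runs as a pre-commit hook via mirrors-mypy.
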
diff --git a/.gitignore b/.gitignore
index c00e966..32c2fec 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,16 +1,11 @@
*.egg-info
-*.iml
*.py[co]
.*.sw[a-z]
-.pytest_cache
.coverage
-.idea
-.project
-.pydevproject
.tox
.venv.touch
+/.mypy_cache
+/.pytest_cache
/venv*
coverage-html
dist
-# SublimeText project/workspace files
-*.sublime-*
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 8bd0fdc..4990537 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -27,7 +27,7 @@
rev: v1.3.5
hooks:
- id: reorder-python-imports
- language_version: python2.7
+ language_version: python3
- repo: https://github.com/asottile/pyupgrade
rev: v1.11.1
hooks:
@@ -36,3 +36,8 @@
rev: v0.7.1
hooks:
- id: add-trailing-comma
+- repo: https://github.com/pre-commit/mirrors-mypy
+ rev: v0.660
+ hooks:
+ - id: mypy
+ language_version: python3
diff --git a/.travis.yml b/.travis.yml
index 477b5c4..fa16cce 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -1,3 +1,4 @@
+dist: xenial
language: python
matrix:
include: # These should match the tox env list
@@ -6,9 +7,8 @@
python: 3.6
- env: TOXENV=py37
python: 3.7
- dist: xenial
- env: TOXENV=pypy
- python: pypy-5.7.1
+ python: pypy2.7-5.10.0
install: pip install coveralls tox
script: tox
before_install:
diff --git a/get-git-lfs.py b/get-git-lfs.py
index 48dd31e..4b09cac 100755
--- a/get-git-lfs.py
+++ b/get-git-lfs.py
@@ -4,7 +4,9 @@
import os.path
import shutil
import tarfile
-from urllib.request import urlopen
+import urllib.request
+from typing import cast
+from typing import IO
DOWNLOAD_PATH = (
'https://github.com/github/git-lfs/releases/download/'
@@ -15,7 +17,7 @@
DEST_DIR = os.path.dirname(DEST_PATH)
-def main():
+def main(): # type: () -> int
if (
os.path.exists(DEST_PATH) and
os.path.isfile(DEST_PATH) and
@@ -27,12 +29,13 @@
shutil.rmtree(DEST_DIR, ignore_errors=True)
os.makedirs(DEST_DIR, exist_ok=True)
- contents = io.BytesIO(urlopen(DOWNLOAD_PATH).read())
+ contents = io.BytesIO(urllib.request.urlopen(DOWNLOAD_PATH).read())
with tarfile.open(fileobj=contents) as tar:
- with tar.extractfile(PATH_IN_TAR) as src_file:
+ with cast(IO[bytes], tar.extractfile(PATH_IN_TAR)) as src_file:
with open(DEST_PATH, 'wb') as dest_file:
shutil.copyfileobj(src_file, dest_file)
os.chmod(DEST_PATH, 0o755)
+ return 0
if __name__ == '__main__':
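
Note: tarfile.TarFile.extractfile() is typed as returning an optional file
object (it is None for non-regular members), so the cast() above narrows
that away before the with block.  An equivalent alternative, sketched here
only for illustration, would be an explicit assert instead of the cast:

    src_file = tar.extractfile(PATH_IN_TAR)
    assert src_file is not None
    with src_file, open(DEST_PATH, 'wb') as dest_file:
        shutil.copyfileobj(src_file, dest_file)
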
diff --git a/mypy.ini b/mypy.ini
new file mode 100644
index 0000000..ee62c89
--- /dev/null
+++ b/mypy.ini
@@ -0,0 +1,12 @@
+[mypy]
+check_untyped_defs = true
+disallow_any_generics = true
+disallow_incomplete_defs = true
+disallow_untyped_defs = true
+no_implicit_optional = true
+
+[mypy-testing.*]
+disallow_untyped_defs = false
+
+[mypy-tests.*]
+disallow_untyped_defs = false
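
Note: disallow_untyped_defs is the setting that requires a type comment on
every function under pre_commit_hooks/, while the [mypy-testing.*] and
[mypy-tests.*] overrides leave the test suite annotation-free.  A small
sketch of what the strict setting rejects versus accepts (hypothetical
code, not from this repo):

    def untyped(x):  # error: Function is missing a type annotation
        return x


    def typed(x):  # type: (int) -> int
        return x
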
diff --git a/pre_commit_hooks/autopep8_wrapper.py b/pre_commit_hooks/autopep8_wrapper.py
index 9951924..8b69a04 100644
--- a/pre_commit_hooks/autopep8_wrapper.py
+++ b/pre_commit_hooks/autopep8_wrapper.py
@@ -3,7 +3,7 @@
from __future__ import unicode_literals
-def main(argv=None):
+def main(): # type: () -> int
raise SystemExit(
'autopep8-wrapper is deprecated. Instead use autopep8 directly via '
'https://github.com/pre-commit/mirrors-autopep8',
diff --git a/pre_commit_hooks/check_added_large_files.py b/pre_commit_hooks/check_added_large_files.py
index 2d06706..be39498 100644
--- a/pre_commit_hooks/check_added_large_files.py
+++ b/pre_commit_hooks/check_added_large_files.py
@@ -7,13 +7,17 @@
import json
import math
import os
+from typing import Iterable
+from typing import Optional
+from typing import Sequence
+from typing import Set
from pre_commit_hooks.util import added_files
from pre_commit_hooks.util import CalledProcessError
from pre_commit_hooks.util import cmd_output
-def lfs_files():
+def lfs_files(): # type: () -> Set[str]
try:
# Introduced in git-lfs 2.2.0, first working in 2.2.1
lfs_ret = cmd_output('git', 'lfs', 'status', '--json')
@@ -24,6 +28,7 @@
def find_large_added_files(filenames, maxkb):
+ # type: (Iterable[str], int) -> int
# Find all added files that are also in the list of files pre-commit tells
# us about
filenames = (added_files() & set(filenames)) - lfs_files()
@@ -38,7 +43,7 @@
return retv
-def main(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument(
'filenames', nargs='*',
diff --git a/pre_commit_hooks/check_ast.py b/pre_commit_hooks/check_ast.py
index ded65e4..0df3540 100644
--- a/pre_commit_hooks/check_ast.py
+++ b/pre_commit_hooks/check_ast.py
@@ -7,9 +7,11 @@
import platform
import sys
import traceback
+from typing import Optional
+from typing import Sequence
-def check_ast(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*')
args = parser.parse_args(argv)
@@ -34,4 +36,4 @@
if __name__ == '__main__':
- exit(check_ast())
+ exit(main())
diff --git a/pre_commit_hooks/check_builtin_literals.py b/pre_commit_hooks/check_builtin_literals.py
index 4a4b9ce..874c68c 100644
--- a/pre_commit_hooks/check_builtin_literals.py
+++ b/pre_commit_hooks/check_builtin_literals.py
@@ -4,6 +4,10 @@
import ast
import collections
import sys
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Set
BUILTIN_TYPES = {
@@ -22,14 +26,17 @@
class BuiltinTypeVisitor(ast.NodeVisitor):
def __init__(self, ignore=None, allow_dict_kwargs=True):
- self.builtin_type_calls = []
+ # type: (Optional[Sequence[str]], bool) -> None
+ self.builtin_type_calls = [] # type: List[BuiltinTypeCall]
self.ignore = set(ignore) if ignore else set()
self.allow_dict_kwargs = allow_dict_kwargs
- def _check_dict_call(self, node):
+ def _check_dict_call(self, node): # type: (ast.Call) -> bool
+
return self.allow_dict_kwargs and (getattr(node, 'kwargs', None) or getattr(node, 'keywords', None))
- def visit_Call(self, node):
+ def visit_Call(self, node): # type: (ast.Call) -> None
+
if not isinstance(node.func, ast.Name):
# Ignore functions that are object attributes (`foo.bar()`).
# Assume that if the user calls `builtins.list()`, they know what
@@ -47,6 +54,7 @@
def check_file_for_builtin_type_constructors(filename, ignore=None, allow_dict_kwargs=True):
+ # type: (str, Optional[Sequence[str]], bool) -> List[BuiltinTypeCall]
with open(filename, 'rb') as f:
tree = ast.parse(f.read(), filename=filename)
visitor = BuiltinTypeVisitor(ignore=ignore, allow_dict_kwargs=allow_dict_kwargs)
@@ -54,24 +62,22 @@
return visitor.builtin_type_calls
-def parse_args(argv):
- def parse_ignore(value):
- return set(value.split(','))
+def parse_ignore(value): # type: (str) -> Set[str]
+ return set(value.split(','))
+
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*')
parser.add_argument('--ignore', type=parse_ignore, default=set())
- allow_dict_kwargs = parser.add_mutually_exclusive_group(required=False)
- allow_dict_kwargs.add_argument('--allow-dict-kwargs', action='store_true')
- allow_dict_kwargs.add_argument('--no-allow-dict-kwargs', dest='allow_dict_kwargs', action='store_false')
- allow_dict_kwargs.set_defaults(allow_dict_kwargs=True)
+ mutex = parser.add_mutually_exclusive_group(required=False)
+ mutex.add_argument('--allow-dict-kwargs', action='store_true')
+ mutex.add_argument('--no-allow-dict-kwargs', dest='allow_dict_kwargs', action='store_false')
+ mutex.set_defaults(allow_dict_kwargs=True)
- return parser.parse_args(argv)
+ args = parser.parse_args(argv)
-
-def main(argv=None):
- args = parse_args(argv)
rc = 0
for filename in args.filenames:
calls = check_file_for_builtin_type_constructors(
diff --git a/pre_commit_hooks/check_byte_order_marker.py b/pre_commit_hooks/check_byte_order_marker.py
index 1541b30..10667c3 100644
--- a/pre_commit_hooks/check_byte_order_marker.py
+++ b/pre_commit_hooks/check_byte_order_marker.py
@@ -3,9 +3,11 @@
from __future__ import unicode_literals
import argparse
+from typing import Optional
+from typing import Sequence
-def main(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='Filenames to check')
args = parser.parse_args(argv)
diff --git a/pre_commit_hooks/check_case_conflict.py b/pre_commit_hooks/check_case_conflict.py
index 0f78296..e343d61 100644
--- a/pre_commit_hooks/check_case_conflict.py
+++ b/pre_commit_hooks/check_case_conflict.py
@@ -3,16 +3,20 @@
from __future__ import unicode_literals
import argparse
+from typing import Iterable
+from typing import Optional
+from typing import Sequence
+from typing import Set
from pre_commit_hooks.util import added_files
from pre_commit_hooks.util import cmd_output
-def lower_set(iterable):
+def lower_set(iterable): # type: (Iterable[str]) -> Set[str]
return {x.lower() for x in iterable}
-def find_conflicting_filenames(filenames):
+def find_conflicting_filenames(filenames): # type: (Sequence[str]) -> int
repo_files = set(cmd_output('git', 'ls-files').splitlines())
relevant_files = set(filenames) | added_files()
repo_files -= relevant_files
@@ -41,7 +45,7 @@
return retv
-def main(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument(
'filenames', nargs='*',
diff --git a/pre_commit_hooks/check_docstring_first.py b/pre_commit_hooks/check_docstring_first.py
index 9988378..f4639f1 100644
--- a/pre_commit_hooks/check_docstring_first.py
+++ b/pre_commit_hooks/check_docstring_first.py
@@ -5,6 +5,8 @@
import argparse
import io
import tokenize
+from typing import Optional
+from typing import Sequence
NON_CODE_TOKENS = frozenset((
@@ -13,6 +15,7 @@
def check_docstring_first(src, filename='<unknown>'):
+ # type: (str, str) -> int
"""Returns nonzero if the source has what looks like a docstring that is
not at the beginning of the source.
@@ -50,7 +53,7 @@
return 0
-def main(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*')
args = parser.parse_args(argv)
diff --git a/pre_commit_hooks/check_executables_have_shebangs.py b/pre_commit_hooks/check_executables_have_shebangs.py
index 89ac6e5..c936a5d 100644
--- a/pre_commit_hooks/check_executables_have_shebangs.py
+++ b/pre_commit_hooks/check_executables_have_shebangs.py
@@ -6,9 +6,11 @@
import argparse
import pipes
import sys
+from typing import Optional
+from typing import Sequence
-def check_has_shebang(path):
+def check_has_shebang(path): # type: (str) -> int
with open(path, 'rb') as f:
first_bytes = f.read(2)
@@ -27,7 +29,7 @@
return 0
-def main(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('filenames', nargs='*')
args = parser.parse_args(argv)
@@ -38,3 +40,7 @@
retv |= check_has_shebang(filename)
return retv
+
+
+if __name__ == '__main__':
+ exit(main())
diff --git a/pre_commit_hooks/check_json.py b/pre_commit_hooks/check_json.py
index b403f4b..b939350 100644
--- a/pre_commit_hooks/check_json.py
+++ b/pre_commit_hooks/check_json.py
@@ -4,9 +4,11 @@
import io
import json
import sys
+from typing import Optional
+from typing import Sequence
-def check_json(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='JSON filenames to check.')
args = parser.parse_args(argv)
@@ -22,4 +24,4 @@
if __name__ == '__main__':
- sys.exit(check_json())
+ sys.exit(main())
diff --git a/pre_commit_hooks/check_merge_conflict.py b/pre_commit_hooks/check_merge_conflict.py
index 6db5efe..74e4ae1 100644
--- a/pre_commit_hooks/check_merge_conflict.py
+++ b/pre_commit_hooks/check_merge_conflict.py
@@ -2,6 +2,9 @@
import argparse
import os.path
+from typing import Optional
+from typing import Sequence
+
CONFLICT_PATTERNS = [
b'<<<<<<< ',
@@ -12,7 +15,7 @@
WARNING_MSG = 'Merge conflict string "{0}" found in {1}:{2}'
-def is_in_merge():
+def is_in_merge(): # type: () -> int
return (
os.path.exists(os.path.join('.git', 'MERGE_MSG')) and
(
@@ -23,7 +26,7 @@
)
-def detect_merge_conflict(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*')
parser.add_argument('--assume-in-merge', action='store_true')
@@ -47,4 +50,4 @@
if __name__ == '__main__':
- exit(detect_merge_conflict())
+ exit(main())
diff --git a/pre_commit_hooks/check_symlinks.py b/pre_commit_hooks/check_symlinks.py
index 010c871..736bf99 100644
--- a/pre_commit_hooks/check_symlinks.py
+++ b/pre_commit_hooks/check_symlinks.py
@@ -4,9 +4,11 @@
import argparse
import os.path
+from typing import Optional
+from typing import Sequence
-def check_symlinks(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser(description='Checks for broken symlinks.')
parser.add_argument('filenames', nargs='*', help='Filenames to check')
args = parser.parse_args(argv)
@@ -25,4 +27,4 @@
if __name__ == '__main__':
- exit(check_symlinks())
+ exit(main())
diff --git a/pre_commit_hooks/check_vcs_permalinks.py b/pre_commit_hooks/check_vcs_permalinks.py
index f0dcf5b..f6e2a7d 100644
--- a/pre_commit_hooks/check_vcs_permalinks.py
+++ b/pre_commit_hooks/check_vcs_permalinks.py
@@ -5,6 +5,8 @@
import argparse
import re
import sys
+from typing import Optional
+from typing import Sequence
GITHUB_NON_PERMALINK = re.compile(
@@ -12,7 +14,7 @@
)
-def _check_filename(filename):
+def _check_filename(filename): # type: (str) -> int
retv = 0
with open(filename, 'rb') as f:
for i, line in enumerate(f, 1):
@@ -24,7 +26,7 @@
return retv
-def main(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*')
args = parser.parse_args(argv)
diff --git a/pre_commit_hooks/check_xml.py b/pre_commit_hooks/check_xml.py
index a4c11a5..66e10ba 100644
--- a/pre_commit_hooks/check_xml.py
+++ b/pre_commit_hooks/check_xml.py
@@ -5,10 +5,12 @@
import argparse
import io
import sys
-import xml.sax
+import xml.sax.handler
+from typing import Optional
+from typing import Sequence
-def check_xml(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='XML filenames to check.')
args = parser.parse_args(argv)
@@ -17,7 +19,7 @@
for filename in args.filenames:
try:
with io.open(filename, 'rb') as xml_file:
- xml.sax.parse(xml_file, xml.sax.ContentHandler())
+ xml.sax.parse(xml_file, xml.sax.handler.ContentHandler())
except xml.sax.SAXException as exc:
print('{}: Failed to xml parse ({})'.format(filename, exc))
retval = 1
@@ -25,4 +27,4 @@
if __name__ == '__main__':
- sys.exit(check_xml())
+ sys.exit(main())
diff --git a/pre_commit_hooks/check_yaml.py b/pre_commit_hooks/check_yaml.py
index 208737f..b638684 100644
--- a/pre_commit_hooks/check_yaml.py
+++ b/pre_commit_hooks/check_yaml.py
@@ -3,22 +3,26 @@
import argparse
import collections
import sys
+from typing import Any
+from typing import Generator
+from typing import Optional
+from typing import Sequence
import ruamel.yaml
yaml = ruamel.yaml.YAML(typ='safe')
-def _exhaust(gen):
+def _exhaust(gen): # type: (Generator[str, None, None]) -> None
for _ in gen:
pass
-def _parse_unsafe(*args, **kwargs):
+def _parse_unsafe(*args, **kwargs): # type: (*Any, **Any) -> None
_exhaust(yaml.parse(*args, **kwargs))
-def _load_all(*args, **kwargs):
+def _load_all(*args, **kwargs): # type: (*Any, **Any) -> None
_exhaust(yaml.load_all(*args, **kwargs))
@@ -31,7 +35,7 @@
}
-def check_yaml(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument(
'-m', '--multi', '--allow-multiple-documents', action='store_true',
@@ -63,4 +67,4 @@
if __name__ == '__main__':
- sys.exit(check_yaml())
+ sys.exit(main())
diff --git a/pre_commit_hooks/debug_statement_hook.py b/pre_commit_hooks/debug_statement_hook.py
index 5d32277..02dd3b2 100644
--- a/pre_commit_hooks/debug_statement_hook.py
+++ b/pre_commit_hooks/debug_statement_hook.py
@@ -5,6 +5,9 @@
import ast
import collections
import traceback
+from typing import List
+from typing import Optional
+from typing import Sequence
DEBUG_STATEMENTS = {'pdb', 'ipdb', 'pudb', 'q', 'rdb'}
@@ -12,21 +15,21 @@
class DebugStatementParser(ast.NodeVisitor):
- def __init__(self):
- self.breakpoints = []
+ def __init__(self): # type: () -> None
+ self.breakpoints = [] # type: List[Debug]
- def visit_Import(self, node):
+ def visit_Import(self, node): # type: (ast.Import) -> None
for name in node.names:
if name.name in DEBUG_STATEMENTS:
st = Debug(node.lineno, node.col_offset, name.name, 'imported')
self.breakpoints.append(st)
- def visit_ImportFrom(self, node):
+ def visit_ImportFrom(self, node): # type: (ast.ImportFrom) -> None
if node.module in DEBUG_STATEMENTS:
st = Debug(node.lineno, node.col_offset, node.module, 'imported')
self.breakpoints.append(st)
- def visit_Call(self, node):
+ def visit_Call(self, node): # type: (ast.Call) -> None
"""python3.7+ breakpoint()"""
if isinstance(node.func, ast.Name) and node.func.id == 'breakpoint':
st = Debug(node.lineno, node.col_offset, node.func.id, 'called')
@@ -34,7 +37,7 @@
self.generic_visit(node)
-def check_file(filename):
+def check_file(filename): # type: (str) -> int
try:
with open(filename, 'rb') as f:
ast_obj = ast.parse(f.read(), filename=filename)
@@ -58,7 +61,7 @@
return int(bool(visitor.breakpoints))
-def main(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='Filenames to run')
args = parser.parse_args(argv)
diff --git a/pre_commit_hooks/detect_aws_credentials.py b/pre_commit_hooks/detect_aws_credentials.py
index ecd9d40..3c87d11 100644
--- a/pre_commit_hooks/detect_aws_credentials.py
+++ b/pre_commit_hooks/detect_aws_credentials.py
@@ -3,11 +3,16 @@
import argparse
import os
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Set
from six.moves import configparser
-def get_aws_credential_files_from_env():
+def get_aws_credential_files_from_env(): # type: () -> Set[str]
"""Extract credential file paths from environment variables."""
files = set()
for env_var in (
@@ -19,7 +24,7 @@
return files
-def get_aws_secrets_from_env():
+def get_aws_secrets_from_env(): # type: () -> Set[str]
"""Extract AWS secrets from environment variables."""
keys = set()
for env_var in (
@@ -30,7 +35,7 @@
return keys
-def get_aws_secrets_from_file(credentials_file):
+def get_aws_secrets_from_file(credentials_file): # type: (str) -> Set[str]
"""Extract AWS secrets from configuration files.
Read an ini-style configuration file and return a set with all found AWS
@@ -62,6 +67,7 @@
def check_file_for_aws_keys(filenames, keys):
+ # type: (Sequence[str], Set[str]) -> List[Dict[str, str]]
"""Check if files contain AWS secrets.
Return a list of all files containing AWS secrets and keys found, with all
@@ -82,7 +88,7 @@
return bad_files
-def main(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='+', help='Filenames to run')
parser.add_argument(
@@ -111,7 +117,7 @@
# of files to gather AWS secrets from.
credential_files |= get_aws_credential_files_from_env()
- keys = set()
+ keys = set() # type: Set[str]
for credential_file in credential_files:
keys |= get_aws_secrets_from_file(credential_file)
diff --git a/pre_commit_hooks/detect_private_key.py b/pre_commit_hooks/detect_private_key.py
index c8ee961..d31957d 100644
--- a/pre_commit_hooks/detect_private_key.py
+++ b/pre_commit_hooks/detect_private_key.py
@@ -2,6 +2,8 @@
import argparse
import sys
+from typing import Optional
+from typing import Sequence
BLACKLIST = [
b'BEGIN RSA PRIVATE KEY',
@@ -15,7 +17,7 @@
]
-def detect_private_key(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='Filenames to check')
args = parser.parse_args(argv)
@@ -37,4 +39,4 @@
if __name__ == '__main__':
- sys.exit(detect_private_key())
+ sys.exit(main())
diff --git a/pre_commit_hooks/end_of_file_fixer.py b/pre_commit_hooks/end_of_file_fixer.py
index 5ab1b7b..4e77c94 100644
--- a/pre_commit_hooks/end_of_file_fixer.py
+++ b/pre_commit_hooks/end_of_file_fixer.py
@@ -4,9 +4,12 @@
import argparse
import os
import sys
+from typing import IO
+from typing import Optional
+from typing import Sequence
-def fix_file(file_obj):
+def fix_file(file_obj): # type: (IO[bytes]) -> int
# Test for newline at end of file
# Empty files will throw IOError here
try:
@@ -49,7 +52,7 @@
return 0
-def end_of_file_fixer(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='Filenames to fix')
args = parser.parse_args(argv)
@@ -68,4 +71,4 @@
if __name__ == '__main__':
- sys.exit(end_of_file_fixer())
+ sys.exit(main())
diff --git a/pre_commit_hooks/file_contents_sorter.py b/pre_commit_hooks/file_contents_sorter.py
index fe7f7ee..6f13c98 100644
--- a/pre_commit_hooks/file_contents_sorter.py
+++ b/pre_commit_hooks/file_contents_sorter.py
@@ -12,12 +12,15 @@
from __future__ import print_function
import argparse
+from typing import IO
+from typing import Optional
+from typing import Sequence
PASS = 0
FAIL = 1
-def sort_file_contents(f):
+def sort_file_contents(f): # type: (IO[bytes]) -> int
before = list(f)
after = sorted([line.strip(b'\n\r') for line in before if line.strip()])
@@ -33,7 +36,7 @@
return FAIL
-def main(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='+', help='Files to sort')
args = parser.parse_args(argv)
diff --git a/pre_commit_hooks/fix_encoding_pragma.py b/pre_commit_hooks/fix_encoding_pragma.py
index 3bf234e..b0b5c8e 100644
--- a/pre_commit_hooks/fix_encoding_pragma.py
+++ b/pre_commit_hooks/fix_encoding_pragma.py
@@ -4,11 +4,15 @@
import argparse
import collections
+from typing import IO
+from typing import Optional
+from typing import Sequence
+from typing import Union
DEFAULT_PRAGMA = b'# -*- coding: utf-8 -*-\n'
-def has_coding(line):
+def has_coding(line): # type: (bytes) -> bool
if not line.strip():
return False
return (
@@ -33,15 +37,16 @@
__slots__ = ()
@property
- def has_any_pragma(self):
+ def has_any_pragma(self): # type: () -> bool
return self.pragma_status is not False
- def is_expected_pragma(self, remove):
+ def is_expected_pragma(self, remove): # type: (bool) -> bool
expected_pragma_status = not remove
return self.pragma_status is expected_pragma_status
def _get_expected_contents(first_line, second_line, rest, expected_pragma):
+ # type: (bytes, bytes, bytes, bytes) -> ExpectedContents
if first_line.startswith(b'#!'):
shebang = first_line
potential_coding = second_line
@@ -51,7 +56,7 @@
rest = second_line + rest
if potential_coding == expected_pragma:
- pragma_status = True
+ pragma_status = True # type: Optional[bool]
elif has_coding(potential_coding):
pragma_status = None
else:
@@ -64,6 +69,7 @@
def fix_encoding_pragma(f, remove=False, expected_pragma=DEFAULT_PRAGMA):
+ # type: (IO[bytes], bool, bytes) -> int
expected = _get_expected_contents(
f.readline(), f.readline(), f.read(), expected_pragma,
)
@@ -93,17 +99,17 @@
return 1
-def _normalize_pragma(pragma):
+def _normalize_pragma(pragma): # type: (Union[bytes, str]) -> bytes
if not isinstance(pragma, bytes):
pragma = pragma.encode('UTF-8')
return pragma.rstrip() + b'\n'
-def _to_disp(pragma):
+def _to_disp(pragma): # type: (bytes) -> str
return pragma.decode().rstrip()
-def main(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser('Fixes the encoding pragma of python files')
parser.add_argument('filenames', nargs='*', help='Filenames to fix')
parser.add_argument(
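
Note on the pragma_status change above: mypy infers a variable's type from
its first assignment, so the widened annotation has to sit on the True
branch; without it, the later pragma_status = None assignment would be
rejected.  The same pattern in isolation (hypothetical code):

    if first:
        status = True  # type: Optional[bool]
    elif second:
        status = None
    else:
        status = False
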
diff --git a/pre_commit_hooks/forbid_new_submodules.py b/pre_commit_hooks/forbid_new_submodules.py
index c9464cf..bdbd6f7 100644
--- a/pre_commit_hooks/forbid_new_submodules.py
+++ b/pre_commit_hooks/forbid_new_submodules.py
@@ -2,10 +2,13 @@
from __future__ import print_function
from __future__ import unicode_literals
+from typing import Optional
+from typing import Sequence
+
from pre_commit_hooks.util import cmd_output
-def main(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
# `argv` is ignored, pre-commit will send us a list of files that we
# don't care about
added_diff = cmd_output(
diff --git a/pre_commit_hooks/mixed_line_ending.py b/pre_commit_hooks/mixed_line_ending.py
index e35a65c..90aef03 100644
--- a/pre_commit_hooks/mixed_line_ending.py
+++ b/pre_commit_hooks/mixed_line_ending.py
@@ -4,6 +4,9 @@
import argparse
import collections
+from typing import Dict
+from typing import Optional
+from typing import Sequence
CRLF = b'\r\n'
@@ -14,7 +17,7 @@
FIX_TO_LINE_ENDING = {'cr': CR, 'crlf': CRLF, 'lf': LF}
-def _fix(filename, contents, ending):
+def _fix(filename, contents, ending): # type: (str, bytes, bytes) -> None
new_contents = b''.join(
line.rstrip(b'\r\n') + ending for line in contents.splitlines(True)
)
@@ -22,11 +25,11 @@
f.write(new_contents)
-def fix_filename(filename, fix):
+def fix_filename(filename, fix): # type: (str, str) -> int
with open(filename, 'rb') as f:
contents = f.read()
- counts = collections.defaultdict(int)
+ counts = collections.defaultdict(int) # type: Dict[bytes, int]
for line in contents.splitlines(True):
for ending in ALL_ENDINGS:
@@ -63,7 +66,7 @@
return other_endings
-def main(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument(
'-f', '--fix',
diff --git a/pre_commit_hooks/no_commit_to_branch.py b/pre_commit_hooks/no_commit_to_branch.py
index fdd146b..6b68c91 100644
--- a/pre_commit_hooks/no_commit_to_branch.py
+++ b/pre_commit_hooks/no_commit_to_branch.py
@@ -1,12 +1,15 @@
from __future__ import print_function
import argparse
+from typing import Optional
+from typing import Sequence
+from typing import Set
from pre_commit_hooks.util import CalledProcessError
from pre_commit_hooks.util import cmd_output
-def is_on_branch(protected):
+def is_on_branch(protected): # type: (Set[str]) -> bool
try:
branch = cmd_output('git', 'symbolic-ref', 'HEAD')
except CalledProcessError:
@@ -15,7 +18,7 @@
return '/'.join(chunks[2:]) in protected
-def main(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument(
'-b', '--branch', action='append',
diff --git a/pre_commit_hooks/pretty_format_json.py b/pre_commit_hooks/pretty_format_json.py
index 363037e..de7f8d7 100644
--- a/pre_commit_hooks/pretty_format_json.py
+++ b/pre_commit_hooks/pretty_format_json.py
@@ -5,12 +5,20 @@
import json
import sys
from collections import OrderedDict
+from typing import List
+from typing import Mapping
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
from six import text_type
-def _get_pretty_format(contents, indent, ensure_ascii=True, sort_keys=True, top_keys=[]):
+def _get_pretty_format(contents, indent, ensure_ascii=True, sort_keys=True, top_keys=()):
+ # type: (str, str, bool, bool, Sequence[str]) -> str
def pairs_first(pairs):
+ # type: (Sequence[Tuple[str, str]]) -> Mapping[str, str]
before = [pair for pair in pairs if pair[0] in top_keys]
before = sorted(before, key=lambda x: top_keys.index(x[0]))
after = [pair for pair in pairs if pair[0] not in top_keys]
@@ -27,13 +35,13 @@
return text_type(json_pretty) + '\n'
-def _autofix(filename, new_contents):
+def _autofix(filename, new_contents): # type: (str, str) -> None
print('Fixing file {}'.format(filename))
with io.open(filename, 'w', encoding='UTF-8') as f:
f.write(new_contents)
-def parse_num_to_int(s):
+def parse_num_to_int(s): # type: (str) -> Union[int, str]
"""Convert string numbers to int, leaving strings as is."""
try:
return int(s)
@@ -41,11 +49,11 @@
return s
-def parse_topkeys(s):
+def parse_topkeys(s): # type: (str) -> List[str]
return s.split(',')
-def pretty_format_json(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument(
'--autofix',
@@ -117,4 +125,4 @@
if __name__ == '__main__':
- sys.exit(pretty_format_json())
+ sys.exit(main())
diff --git a/pre_commit_hooks/requirements_txt_fixer.py b/pre_commit_hooks/requirements_txt_fixer.py
index 6dcf8d0..3f85a17 100644
--- a/pre_commit_hooks/requirements_txt_fixer.py
+++ b/pre_commit_hooks/requirements_txt_fixer.py
@@ -1,6 +1,10 @@
from __future__ import print_function
import argparse
+from typing import IO
+from typing import List
+from typing import Optional
+from typing import Sequence
PASS = 0
@@ -9,21 +13,23 @@
class Requirement(object):
- def __init__(self):
+ def __init__(self): # type: () -> None
super(Requirement, self).__init__()
- self.value = None
- self.comments = []
+ self.value = None # type: Optional[bytes]
+ self.comments = [] # type: List[bytes]
@property
- def name(self):
+ def name(self): # type: () -> bytes
+ assert self.value is not None, self.value
if self.value.startswith(b'-e '):
return self.value.lower().partition(b'=')[-1]
return self.value.lower().partition(b'==')[0]
- def __lt__(self, requirement):
+ def __lt__(self, requirement): # type: (Requirement) -> int
# \n means top of file comment, so always return True,
# otherwise just do a string comparison with value.
+ assert self.value is not None, self.value
if self.value == b'\n':
return True
elif requirement.value == b'\n':
@@ -32,10 +38,10 @@
return self.name < requirement.name
-def fix_requirements(f):
- requirements = []
+def fix_requirements(f): # type: (IO[bytes]) -> int
+ requirements = [] # type: List[Requirement]
before = tuple(f)
- after = []
+ after = [] # type: List[bytes]
before_string = b''.join(before)
@@ -46,6 +52,7 @@
for line in before:
# If the most recent requirement object has a value, then it's
# time to start building the next requirement object.
+
if not len(requirements) or requirements[-1].value is not None:
requirements.append(Requirement())
@@ -78,6 +85,7 @@
for requirement in sorted(requirements):
after.extend(requirement.comments)
+ assert requirement.value, requirement.value
after.append(requirement.value)
after.extend(rest)
@@ -92,7 +100,7 @@
return FAIL
-def fix_requirements_txt(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='Filenames to fix')
args = parser.parse_args(argv)
@@ -109,3 +117,7 @@
retv |= ret_for_file
return retv
+
+
+if __name__ == '__main__':
+ exit(main())
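
Note: Requirement.value is now typed Optional[bytes], so the added asserts
narrow the Optional before the startswith()/comparison calls and double as
a record of the invariant that name/__lt__ only run on populated
requirements.  The narrowing pattern in isolation (hypothetical helper):

    def req_name(value):  # type: (Optional[bytes]) -> bytes
        assert value is not None
        return value.lower().partition(b'==')[0]
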
diff --git a/pre_commit_hooks/sort_simple_yaml.py b/pre_commit_hooks/sort_simple_yaml.py
index 7afae91..3c8ef16 100755
--- a/pre_commit_hooks/sort_simple_yaml.py
+++ b/pre_commit_hooks/sort_simple_yaml.py
@@ -21,12 +21,15 @@
from __future__ import print_function
import argparse
+from typing import List
+from typing import Optional
+from typing import Sequence
QUOTES = ["'", '"']
-def sort(lines):
+def sort(lines): # type: (List[str]) -> List[str]
"""Sort a YAML file in alphabetical order, keeping blocks together.
:param lines: array of strings (without newlines)
@@ -44,7 +47,7 @@
return new_lines
-def parse_block(lines, header=False):
+def parse_block(lines, header=False): # type: (List[str], bool) -> List[str]
"""Parse and return a single block, popping off the start of `lines`.
If parsing a header block, we stop after we reach a line that is not a
@@ -60,7 +63,7 @@
return block_lines
-def parse_blocks(lines):
+def parse_blocks(lines): # type: (List[str]) -> List[List[str]]
"""Parse and return all possible blocks, popping off the start of `lines`.
:param lines: list of lines
@@ -77,7 +80,7 @@
return blocks
-def first_key(lines):
+def first_key(lines): # type: (List[str]) -> str
"""Returns a string representing the sort key of a block.
The sort key is the first YAML key we encounter, ignoring comments, and
@@ -95,9 +98,11 @@
if any(line.startswith(quote) for quote in QUOTES):
return line[1:]
return line
+ else:
+ return '' # not reached in practice
-def main(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='Filenames to fix')
args = parser.parse_args(argv)
diff --git a/pre_commit_hooks/string_fixer.py b/pre_commit_hooks/string_fixer.py
index c432682..a5ea1ea 100644
--- a/pre_commit_hooks/string_fixer.py
+++ b/pre_commit_hooks/string_fixer.py
@@ -4,34 +4,39 @@
import argparse
import io
+import re
import tokenize
+from typing import List
+from typing import Optional
+from typing import Sequence
+
+START_QUOTE_RE = re.compile('^[a-zA-Z]*"')
-double_quote_starts = tuple(s for s in tokenize.single_quoted if '"' in s)
-
-
-def handle_match(token_text):
+def handle_match(token_text): # type: (str) -> str
if '"""' in token_text or "'''" in token_text:
return token_text
- for double_quote_start in double_quote_starts:
- if token_text.startswith(double_quote_start):
- meat = token_text[len(double_quote_start):-1]
- if '"' in meat or "'" in meat:
- break
- return double_quote_start.replace('"', "'") + meat + "'"
- return token_text
+ match = START_QUOTE_RE.match(token_text)
+ if match is not None:
+ meat = token_text[match.end():-1]
+ if '"' in meat or "'" in meat:
+ return token_text
+ else:
+ return match.group().replace('"', "'") + meat + "'"
+ else:
+ return token_text
-def get_line_offsets_by_line_no(src):
+def get_line_offsets_by_line_no(src): # type: (str) -> List[int]
# Padded so we can index with line number
- offsets = [None, 0]
+ offsets = [-1, 0]
for line in src.splitlines():
offsets.append(offsets[-1] + len(line) + 1)
return offsets
-def fix_strings(filename):
+def fix_strings(filename): # type: (str) -> int
with io.open(filename, encoding='UTF-8') as f:
contents = f.read()
line_offsets = get_line_offsets_by_line_no(contents)
@@ -60,7 +65,7 @@
return 0
-def main(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='Filenames to fix')
args = parser.parse_args(argv)
@@ -74,3 +79,7 @@
retv |= return_value
return retv
+
+
+if __name__ == '__main__':
+ exit(main())
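
Note: START_QUOTE_RE replaces the old introspection of
tokenize.single_quoted; it matches an optional string prefix followed by a
double quote.  Behaviour of handle_match() on a few sample tokens
(illustrative values only):

    assert handle_match('"foo"') == "'foo'"
    assert handle_match('b"foo"') == "b'foo'"
    assert handle_match('"""doc"""') == '"""doc"""'  # triple quotes untouched
    assert handle_match('"it\'s"') == '"it\'s"'      # mixed quotes untouched
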
diff --git a/pre_commit_hooks/tests_should_end_in_test.py b/pre_commit_hooks/tests_should_end_in_test.py
index 9bea20d..7a1e7c0 100644
--- a/pre_commit_hooks/tests_should_end_in_test.py
+++ b/pre_commit_hooks/tests_should_end_in_test.py
@@ -1,12 +1,14 @@
from __future__ import print_function
import argparse
+import os.path
import re
import sys
-from os.path import basename
+from typing import Optional
+from typing import Sequence
-def validate_files(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*')
parser.add_argument(
@@ -18,7 +20,7 @@
retcode = 0
test_name_pattern = 'test.*.py' if args.django else '.*_test.py'
for filename in args.filenames:
- base = basename(filename)
+ base = os.path.basename(filename)
if (
not re.match(test_name_pattern, base) and
not base == '__init__.py' and
@@ -35,4 +37,4 @@
if __name__ == '__main__':
- sys.exit(validate_files())
+ sys.exit(main())
diff --git a/pre_commit_hooks/trailing_whitespace_fixer.py b/pre_commit_hooks/trailing_whitespace_fixer.py
index 1b54fbd..4fe7975 100644
--- a/pre_commit_hooks/trailing_whitespace_fixer.py
+++ b/pre_commit_hooks/trailing_whitespace_fixer.py
@@ -3,9 +3,11 @@
import argparse
import os
import sys
+from typing import Optional
+from typing import Sequence
-def _fix_file(filename, is_markdown):
+def _fix_file(filename, is_markdown): # type: (str, bool) -> bool
with open(filename, mode='rb') as file_processed:
lines = file_processed.readlines()
newlines = [_process_line(line, is_markdown) for line in lines]
@@ -18,7 +20,7 @@
return False
-def _process_line(line, is_markdown):
+def _process_line(line, is_markdown): # type: (bytes, bool) -> bytes
if line[-2:] == b'\r\n':
eol = b'\r\n'
elif line[-1:] == b'\n':
@@ -31,7 +33,7 @@
return line.rstrip() + eol
-def main(argv=None):
+def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument(
'--no-markdown-linebreak-ext',
diff --git a/pre_commit_hooks/util.py b/pre_commit_hooks/util.py
index 269b553..5d1d11b 100644
--- a/pre_commit_hooks/util.py
+++ b/pre_commit_hooks/util.py
@@ -3,23 +3,25 @@
from __future__ import unicode_literals
import subprocess
+from typing import Any
+from typing import Set
class CalledProcessError(RuntimeError):
pass
-def added_files():
+def added_files(): # type: () -> Set[str]
return set(cmd_output(
'git', 'diff', '--staged', '--name-only', '--diff-filter=A',
).splitlines())
-def cmd_output(*cmd, **kwargs):
+def cmd_output(*cmd, **kwargs): # type: (*str, **Any) -> str
retcode = kwargs.pop('retcode', 0)
- popen_kwargs = {'stdout': subprocess.PIPE, 'stderr': subprocess.PIPE}
- popen_kwargs.update(kwargs)
- proc = subprocess.Popen(cmd, **popen_kwargs)
+ kwargs.setdefault('stdout', subprocess.PIPE)
+ kwargs.setdefault('stderr', subprocess.PIPE)
+ proc = subprocess.Popen(cmd, **kwargs)
stdout, stderr = proc.communicate()
stdout = stdout.decode('UTF-8')
if stderr is not None:
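
Note: cmd_output() now fills in stdout/stderr with setdefault() instead of
building an intermediate popen_kwargs dict, presumably because mypy would
infer that literal as Dict[str, int] (from the PIPE constants) and object
to the extra Popen options merged into it; **kwargs stays **Any since
callers pass arbitrary Popen options plus the hook-specific retcode.
Typical call from the hooks (as used elsewhere in this diff):

    repo_files = set(cmd_output('git', 'ls-files').splitlines())
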
diff --git a/setup.py b/setup.py
index 84892a7..756500b 100644
--- a/setup.py
+++ b/setup.py
@@ -28,35 +28,36 @@
'ruamel.yaml>=0.15',
'six',
],
+ extras_require={':python_version<"3.5"': ['typing']},
entry_points={
'console_scripts': [
'autopep8-wrapper = pre_commit_hooks.autopep8_wrapper:main',
'check-added-large-files = pre_commit_hooks.check_added_large_files:main',
- 'check-ast = pre_commit_hooks.check_ast:check_ast',
+ 'check-ast = pre_commit_hooks.check_ast:main',
'check-builtin-literals = pre_commit_hooks.check_builtin_literals:main',
'check-byte-order-marker = pre_commit_hooks.check_byte_order_marker:main',
'check-case-conflict = pre_commit_hooks.check_case_conflict:main',
'check-docstring-first = pre_commit_hooks.check_docstring_first:main',
'check-executables-have-shebangs = pre_commit_hooks.check_executables_have_shebangs:main',
- 'check-json = pre_commit_hooks.check_json:check_json',
- 'check-merge-conflict = pre_commit_hooks.check_merge_conflict:detect_merge_conflict',
- 'check-symlinks = pre_commit_hooks.check_symlinks:check_symlinks',
+ 'check-json = pre_commit_hooks.check_json:main',
+ 'check-merge-conflict = pre_commit_hooks.check_merge_conflict:main',
+ 'check-symlinks = pre_commit_hooks.check_symlinks:main',
'check-vcs-permalinks = pre_commit_hooks.check_vcs_permalinks:main',
- 'check-xml = pre_commit_hooks.check_xml:check_xml',
- 'check-yaml = pre_commit_hooks.check_yaml:check_yaml',
+ 'check-xml = pre_commit_hooks.check_xml:main',
+ 'check-yaml = pre_commit_hooks.check_yaml:main',
'debug-statement-hook = pre_commit_hooks.debug_statement_hook:main',
'detect-aws-credentials = pre_commit_hooks.detect_aws_credentials:main',
- 'detect-private-key = pre_commit_hooks.detect_private_key:detect_private_key',
+ 'detect-private-key = pre_commit_hooks.detect_private_key:main',
'double-quote-string-fixer = pre_commit_hooks.string_fixer:main',
- 'end-of-file-fixer = pre_commit_hooks.end_of_file_fixer:end_of_file_fixer',
+ 'end-of-file-fixer = pre_commit_hooks.end_of_file_fixer:main',
'file-contents-sorter = pre_commit_hooks.file_contents_sorter:main',
'fix-encoding-pragma = pre_commit_hooks.fix_encoding_pragma:main',
'forbid-new-submodules = pre_commit_hooks.forbid_new_submodules:main',
'mixed-line-ending = pre_commit_hooks.mixed_line_ending:main',
- 'name-tests-test = pre_commit_hooks.tests_should_end_in_test:validate_files',
+ 'name-tests-test = pre_commit_hooks.tests_should_end_in_test:main',
'no-commit-to-branch = pre_commit_hooks.no_commit_to_branch:main',
- 'pretty-format-json = pre_commit_hooks.pretty_format_json:pretty_format_json',
- 'requirements-txt-fixer = pre_commit_hooks.requirements_txt_fixer:fix_requirements_txt',
+ 'pretty-format-json = pre_commit_hooks.pretty_format_json:main',
+ 'requirements-txt-fixer = pre_commit_hooks.requirements_txt_fixer:main',
'sort-simple-yaml = pre_commit_hooks.sort_simple_yaml:main',
'trailing-whitespace-fixer = pre_commit_hooks.trailing_whitespace_fixer:main',
],
diff --git a/testing/resources/bad_json_latin1.nonjson b/testing/resources/bad_json_latin1.nonjson
old mode 100755
new mode 100644
diff --git a/testing/resources/builtin_constructors.py b/testing/resources/builtin_constructors.py
deleted file mode 100644
index 174a9e8..0000000
--- a/testing/resources/builtin_constructors.py
+++ /dev/null
@@ -1,17 +0,0 @@
-from six.moves import builtins
-
-c1 = complex()
-d1 = dict()
-f1 = float()
-i1 = int()
-l1 = list()
-s1 = str()
-t1 = tuple()
-
-c2 = builtins.complex()
-d2 = builtins.dict()
-f2 = builtins.float()
-i2 = builtins.int()
-l2 = builtins.list()
-s2 = builtins.str()
-t2 = builtins.tuple()
diff --git a/testing/resources/builtin_literals.py b/testing/resources/builtin_literals.py
deleted file mode 100644
index 8513b70..0000000
--- a/testing/resources/builtin_literals.py
+++ /dev/null
@@ -1,7 +0,0 @@
-c1 = 0j
-d1 = {}
-f1 = 0.0
-i1 = 0
-l1 = []
-s1 = ''
-t1 = ()
diff --git a/tests/check_ast_test.py b/tests/check_ast_test.py
index 64916ba..c16f5fc 100644
--- a/tests/check_ast_test.py
+++ b/tests/check_ast_test.py
@@ -1,15 +1,15 @@
from __future__ import absolute_import
from __future__ import unicode_literals
-from pre_commit_hooks.check_ast import check_ast
+from pre_commit_hooks.check_ast import main
from testing.util import get_resource_path
def test_failing_file():
- ret = check_ast([get_resource_path('cannot_parse_ast.notpy')])
+ ret = main([get_resource_path('cannot_parse_ast.notpy')])
assert ret == 1
def test_passing_file():
- ret = check_ast([__file__])
+ ret = main([__file__])
assert ret == 0
diff --git a/tests/check_builtin_literals_test.py b/tests/check_builtin_literals_test.py
index 86b79e3..d4ac30f 100644
--- a/tests/check_builtin_literals_test.py
+++ b/tests/check_builtin_literals_test.py
@@ -5,7 +5,35 @@
from pre_commit_hooks.check_builtin_literals import BuiltinTypeCall
from pre_commit_hooks.check_builtin_literals import BuiltinTypeVisitor
from pre_commit_hooks.check_builtin_literals import main
-from testing.util import get_resource_path
+
+BUILTIN_CONSTRUCTORS = '''\
+from six.moves import builtins
+
+c1 = complex()
+d1 = dict()
+f1 = float()
+i1 = int()
+l1 = list()
+s1 = str()
+t1 = tuple()
+
+c2 = builtins.complex()
+d2 = builtins.dict()
+f2 = builtins.float()
+i2 = builtins.int()
+l2 = builtins.list()
+s2 = builtins.str()
+t2 = builtins.tuple()
+'''
+BUILTIN_LITERALS = '''\
+c1 = 0j
+d1 = {}
+f1 = 0.0
+i1 = 0
+l1 = []
+s1 = ''
+t1 = ()
+'''
@pytest.fixture
@@ -94,24 +122,26 @@
def test_ignore_constructors():
visitor = BuiltinTypeVisitor(ignore=('complex', 'dict', 'float', 'int', 'list', 'str', 'tuple'))
- with open(get_resource_path('builtin_constructors.py'), 'rb') as f:
- visitor.visit(ast.parse(f.read(), 'builtin_constructors.py'))
+ visitor.visit(ast.parse(BUILTIN_CONSTRUCTORS))
assert visitor.builtin_type_calls == []
-def test_failing_file():
- rc = main([get_resource_path('builtin_constructors.py')])
+def test_failing_file(tmpdir):
+ f = tmpdir.join('f.py')
+ f.write(BUILTIN_CONSTRUCTORS)
+ rc = main([f.strpath])
assert rc == 1
-def test_passing_file():
- rc = main([get_resource_path('builtin_literals.py')])
+def test_passing_file(tmpdir):
+ f = tmpdir.join('f.py')
+ f.write(BUILTIN_LITERALS)
+ rc = main([f.strpath])
assert rc == 0
-def test_failing_file_ignore_all():
- rc = main([
- '--ignore=complex,dict,float,int,list,str,tuple',
- get_resource_path('builtin_constructors.py'),
- ])
+def test_failing_file_ignore_all(tmpdir):
+ f = tmpdir.join('f.py')
+ f.write(BUILTIN_CONSTRUCTORS)
+ rc = main(['--ignore=complex,dict,float,int,list,str,tuple', f.strpath])
assert rc == 0
diff --git a/tests/check_json_test.py b/tests/check_json_test.py
index 6ba26c1..6654ed1 100644
--- a/tests/check_json_test.py
+++ b/tests/check_json_test.py
@@ -1,6 +1,6 @@
import pytest
-from pre_commit_hooks.check_json import check_json
+from pre_commit_hooks.check_json import main
from testing.util import get_resource_path
@@ -11,8 +11,8 @@
('ok_json.json', 0),
),
)
-def test_check_json(capsys, filename, expected_retval):
- ret = check_json([get_resource_path(filename)])
+def test_main(capsys, filename, expected_retval):
+ ret = main([get_resource_path(filename)])
assert ret == expected_retval
if expected_retval == 1:
stdout, _ = capsys.readouterr()
diff --git a/tests/check_merge_conflict_test.py b/tests/check_merge_conflict_test.py
index b04c70e..50e389c 100644
--- a/tests/check_merge_conflict_test.py
+++ b/tests/check_merge_conflict_test.py
@@ -6,7 +6,7 @@
import pytest
-from pre_commit_hooks.check_merge_conflict import detect_merge_conflict
+from pre_commit_hooks.check_merge_conflict import main
from pre_commit_hooks.util import cmd_output
from testing.util import get_resource_path
@@ -102,7 +102,7 @@
@pytest.mark.usefixtures('f1_is_a_conflict_file')
def test_merge_conflicts_git():
- assert detect_merge_conflict(['f1']) == 1
+ assert main(['f1']) == 1
@pytest.mark.parametrize(
@@ -110,7 +110,7 @@
)
def test_merge_conflicts_failing(contents, repository_pending_merge):
repository_pending_merge.join('f2').write_binary(contents)
- assert detect_merge_conflict(['f2']) == 1
+ assert main(['f2']) == 1
@pytest.mark.parametrize(
@@ -118,22 +118,22 @@
)
def test_merge_conflicts_ok(contents, f1_is_a_conflict_file):
f1_is_a_conflict_file.join('f1').write_binary(contents)
- assert detect_merge_conflict(['f1']) == 0
+ assert main(['f1']) == 0
@pytest.mark.usefixtures('f1_is_a_conflict_file')
def test_ignores_binary_files():
shutil.copy(get_resource_path('img1.jpg'), 'f1')
- assert detect_merge_conflict(['f1']) == 0
+ assert main(['f1']) == 0
def test_does_not_care_when_not_in_a_merge(tmpdir):
f = tmpdir.join('README.md')
f.write_binary(b'problem\n=======\n')
- assert detect_merge_conflict([str(f.realpath())]) == 0
+ assert main([str(f.realpath())]) == 0
def test_care_when_assumed_merge(tmpdir):
f = tmpdir.join('README.md')
f.write_binary(b'problem\n=======\n')
- assert detect_merge_conflict([str(f.realpath()), '--assume-in-merge']) == 1
+ assert main([str(f.realpath()), '--assume-in-merge']) == 1
diff --git a/tests/check_symlinks_test.py b/tests/check_symlinks_test.py
index 0414df5..ecbc7ae 100644
--- a/tests/check_symlinks_test.py
+++ b/tests/check_symlinks_test.py
@@ -2,7 +2,7 @@
import pytest
-from pre_commit_hooks.check_symlinks import check_symlinks
+from pre_commit_hooks.check_symlinks import main
xfail_symlink = pytest.mark.xfail(os.name == 'nt', reason='No symlink support')
@@ -12,12 +12,12 @@
@pytest.mark.parametrize(
('dest', 'expected'), (('exists', 0), ('does-not-exist', 1)),
)
-def test_check_symlinks(tmpdir, dest, expected): # pragma: no cover (symlinks)
+def test_main(tmpdir, dest, expected): # pragma: no cover (symlinks)
tmpdir.join('exists').ensure()
symlink = tmpdir.join('symlink')
symlink.mksymlinkto(tmpdir.join(dest))
- assert check_symlinks((symlink.strpath,)) == expected
+ assert main((symlink.strpath,)) == expected
-def test_check_symlinks_normal_file(tmpdir):
- assert check_symlinks((tmpdir.join('f').ensure().strpath,)) == 0
+def test_main_normal_file(tmpdir):
+ assert main((tmpdir.join('f').ensure().strpath,)) == 0
diff --git a/tests/check_xml_test.py b/tests/check_xml_test.py
index 84e365d..357bad6 100644
--- a/tests/check_xml_test.py
+++ b/tests/check_xml_test.py
@@ -1,6 +1,6 @@
import pytest
-from pre_commit_hooks.check_xml import check_xml
+from pre_commit_hooks.check_xml import main
from testing.util import get_resource_path
@@ -10,6 +10,6 @@
('ok_xml.xml', 0),
),
)
-def test_check_xml(filename, expected_retval):
- ret = check_xml([get_resource_path(filename)])
+def test_main(filename, expected_retval):
+ ret = main([get_resource_path(filename)])
assert ret == expected_retval
diff --git a/tests/check_yaml_test.py b/tests/check_yaml_test.py
index aa357f1..d267150 100644
--- a/tests/check_yaml_test.py
+++ b/tests/check_yaml_test.py
@@ -3,7 +3,7 @@
import pytest
-from pre_commit_hooks.check_yaml import check_yaml
+from pre_commit_hooks.check_yaml import main
from testing.util import get_resource_path
@@ -13,29 +13,29 @@
('ok_yaml.yaml', 0),
),
)
-def test_check_yaml(filename, expected_retval):
- ret = check_yaml([get_resource_path(filename)])
+def test_main(filename, expected_retval):
+ ret = main([get_resource_path(filename)])
assert ret == expected_retval
-def test_check_yaml_allow_multiple_documents(tmpdir):
+def test_main_allow_multiple_documents(tmpdir):
f = tmpdir.join('test.yaml')
f.write('---\nfoo\n---\nbar\n')
# should fail without the setting
- assert check_yaml((f.strpath,))
+ assert main((f.strpath,))
# should pass when we allow multiple documents
- assert not check_yaml(('--allow-multiple-documents', f.strpath))
+ assert not main(('--allow-multiple-documents', f.strpath))
def test_fails_even_with_allow_multiple_documents(tmpdir):
f = tmpdir.join('test.yaml')
f.write('[')
- assert check_yaml(('--allow-multiple-documents', f.strpath))
+ assert main(('--allow-multiple-documents', f.strpath))
-def test_check_yaml_unsafe(tmpdir):
+def test_main_unsafe(tmpdir):
f = tmpdir.join('test.yaml')
f.write(
'some_foo: !vault |\n'
@@ -43,12 +43,12 @@
' deadbeefdeadbeefdeadbeef\n',
)
# should fail "safe" check
- assert check_yaml((f.strpath,))
+ assert main((f.strpath,))
# should pass when we allow unsafe documents
- assert not check_yaml(('--unsafe', f.strpath))
+ assert not main(('--unsafe', f.strpath))
-def test_check_yaml_unsafe_still_fails_on_syntax_errors(tmpdir):
+def test_main_unsafe_still_fails_on_syntax_errors(tmpdir):
f = tmpdir.join('test.yaml')
f.write('[')
- assert check_yaml(('--unsafe', f.strpath))
+ assert main(('--unsafe', f.strpath))
diff --git a/tests/detect_private_key_test.py b/tests/detect_private_key_test.py
index fdd63a2..9266f2b 100644
--- a/tests/detect_private_key_test.py
+++ b/tests/detect_private_key_test.py
@@ -1,6 +1,6 @@
import pytest
-from pre_commit_hooks.detect_private_key import detect_private_key
+from pre_commit_hooks.detect_private_key import main
# Input, expected return value
TESTS = (
@@ -18,7 +18,7 @@
@pytest.mark.parametrize(('input_s', 'expected_retval'), TESTS)
-def test_detect_private_key(input_s, expected_retval, tmpdir):
+def test_main(input_s, expected_retval, tmpdir):
path = tmpdir.join('file.txt')
path.write_binary(input_s)
- assert detect_private_key([path.strpath]) == expected_retval
+ assert main([path.strpath]) == expected_retval
diff --git a/tests/end_of_file_fixer_test.py b/tests/end_of_file_fixer_test.py
index f8710af..7f644e7 100644
--- a/tests/end_of_file_fixer_test.py
+++ b/tests/end_of_file_fixer_test.py
@@ -2,8 +2,8 @@
import pytest
-from pre_commit_hooks.end_of_file_fixer import end_of_file_fixer
from pre_commit_hooks.end_of_file_fixer import fix_file
+from pre_commit_hooks.end_of_file_fixer import main
# Input, expected return value, expected output
@@ -35,7 +35,7 @@
path = tmpdir.join('file.txt')
path.write_binary(input_s)
- ret = end_of_file_fixer([path.strpath])
+ ret = main([path.strpath])
file_output = path.read_binary()
assert file_output == output
diff --git a/tests/no_commit_to_branch_test.py b/tests/no_commit_to_branch_test.py
index c275bf7..e978ba2 100644
--- a/tests/no_commit_to_branch_test.py
+++ b/tests/no_commit_to_branch_test.py
@@ -11,24 +11,24 @@
def test_other_branch(temp_git_dir):
with temp_git_dir.as_cwd():
cmd_output('git', 'checkout', '-b', 'anotherbranch')
- assert is_on_branch(('master',)) is False
+ assert is_on_branch({'master'}) is False
def test_multi_branch(temp_git_dir):
with temp_git_dir.as_cwd():
cmd_output('git', 'checkout', '-b', 'another/branch')
- assert is_on_branch(('master',)) is False
+ assert is_on_branch({'master'}) is False
def test_multi_branch_fail(temp_git_dir):
with temp_git_dir.as_cwd():
cmd_output('git', 'checkout', '-b', 'another/branch')
- assert is_on_branch(('another/branch',)) is True
+ assert is_on_branch({'another/branch'}) is True
def test_master_branch(temp_git_dir):
with temp_git_dir.as_cwd():
- assert is_on_branch(('master',)) is True
+ assert is_on_branch({'master'}) is True
def test_main_branch_call(temp_git_dir):
diff --git a/tests/pretty_format_json_test.py b/tests/pretty_format_json_test.py
index 7ce7e16..8d82d74 100644
--- a/tests/pretty_format_json_test.py
+++ b/tests/pretty_format_json_test.py
@@ -3,8 +3,8 @@
import pytest
from six import PY2
+from pre_commit_hooks.pretty_format_json import main
from pre_commit_hooks.pretty_format_json import parse_num_to_int
-from pre_commit_hooks.pretty_format_json import pretty_format_json
from testing.util import get_resource_path
@@ -23,8 +23,8 @@
('pretty_formatted_json.json', 0),
),
)
-def test_pretty_format_json(filename, expected_retval):
- ret = pretty_format_json([get_resource_path(filename)])
+def test_main(filename, expected_retval):
+ ret = main([get_resource_path(filename)])
assert ret == expected_retval
@@ -36,8 +36,8 @@
('pretty_formatted_json.json', 0),
),
)
-def test_unsorted_pretty_format_json(filename, expected_retval):
- ret = pretty_format_json(['--no-sort-keys', get_resource_path(filename)])
+def test_unsorted_main(filename, expected_retval):
+ ret = main(['--no-sort-keys', get_resource_path(filename)])
assert ret == expected_retval
@@ -51,17 +51,17 @@
('tab_pretty_formatted_json.json', 0),
),
)
-def test_tab_pretty_format_json(filename, expected_retval): # pragma: no cover
- ret = pretty_format_json(['--indent', '\t', get_resource_path(filename)])
+def test_tab_main(filename, expected_retval): # pragma: no cover
+ ret = main(['--indent', '\t', get_resource_path(filename)])
assert ret == expected_retval
-def test_non_ascii_pretty_format_json():
- ret = pretty_format_json(['--no-ensure-ascii', get_resource_path('non_ascii_pretty_formatted_json.json')])
+def test_non_ascii_main():
+ ret = main(['--no-ensure-ascii', get_resource_path('non_ascii_pretty_formatted_json.json')])
assert ret == 0
-def test_autofix_pretty_format_json(tmpdir):
+def test_autofix_main(tmpdir):
srcfile = tmpdir.join('to_be_json_formatted.json')
shutil.copyfile(
get_resource_path('not_pretty_formatted_json.json'),
@@ -69,30 +69,30 @@
)
# now launch the autofix on that file
- ret = pretty_format_json(['--autofix', srcfile.strpath])
+ ret = main(['--autofix', srcfile.strpath])
# it should have formatted it
assert ret == 1
# file was formatted (shouldn't trigger linter again)
- ret = pretty_format_json([srcfile.strpath])
+ ret = main([srcfile.strpath])
assert ret == 0
def test_orderfile_get_pretty_format():
- ret = pretty_format_json(['--top-keys=alist', get_resource_path('pretty_formatted_json.json')])
+ ret = main(['--top-keys=alist', get_resource_path('pretty_formatted_json.json')])
assert ret == 0
def test_not_orderfile_get_pretty_format():
- ret = pretty_format_json(['--top-keys=blah', get_resource_path('pretty_formatted_json.json')])
+ ret = main(['--top-keys=blah', get_resource_path('pretty_formatted_json.json')])
assert ret == 1
def test_top_sorted_get_pretty_format():
- ret = pretty_format_json(['--top-keys=01-alist,alist', get_resource_path('top_sorted_json.json')])
+ ret = main(['--top-keys=01-alist,alist', get_resource_path('top_sorted_json.json')])
assert ret == 0
-def test_badfile_pretty_format_json():
- ret = pretty_format_json([get_resource_path('ok_yaml.yaml')])
+def test_badfile_main():
+ ret = main([get_resource_path('ok_yaml.yaml')])
assert ret == 1
diff --git a/tests/requirements_txt_fixer_test.py b/tests/requirements_txt_fixer_test.py
index 437cebd..b3a7942 100644
--- a/tests/requirements_txt_fixer_test.py
+++ b/tests/requirements_txt_fixer_test.py
@@ -1,7 +1,7 @@
import pytest
from pre_commit_hooks.requirements_txt_fixer import FAIL
-from pre_commit_hooks.requirements_txt_fixer import fix_requirements_txt
+from pre_commit_hooks.requirements_txt_fixer import main
from pre_commit_hooks.requirements_txt_fixer import PASS
from pre_commit_hooks.requirements_txt_fixer import Requirement
@@ -36,7 +36,7 @@
path = tmpdir.join('file.txt')
path.write_binary(input_s)
- output_retval = fix_requirements_txt([path.strpath])
+ output_retval = main([path.strpath])
assert path.read_binary() == output
assert output_retval == expected_retval
@@ -44,7 +44,7 @@
def test_requirement_object():
top_of_file = Requirement()
- top_of_file.comments.append('#foo')
+ top_of_file.comments.append(b'#foo')
top_of_file.value = b'\n'
requirement_foo = Requirement()
diff --git a/tests/sort_simple_yaml_test.py b/tests/sort_simple_yaml_test.py
index 176d12f..72f5bec 100644
--- a/tests/sort_simple_yaml_test.py
+++ b/tests/sort_simple_yaml_test.py
@@ -110,9 +110,9 @@
lines = ['# some comment', '"a": 42', 'b: 17', '', 'c: 19']
assert first_key(lines) == 'a": 42'
- # no lines
+ # no lines (not a real situation)
lines = []
- assert first_key(lines) is None
+ assert first_key(lines) == ''
@pytest.mark.parametrize('bad_lines,good_lines,_', TEST_SORTS)
diff --git a/tests/tests_should_end_in_test_test.py b/tests/tests_should_end_in_test_test.py
index dc686a5..4eb98e7 100644
--- a/tests/tests_should_end_in_test_test.py
+++ b/tests/tests_should_end_in_test_test.py
@@ -1,36 +1,36 @@
-from pre_commit_hooks.tests_should_end_in_test import validate_files
+from pre_commit_hooks.tests_should_end_in_test import main
-def test_validate_files_all_pass():
- ret = validate_files(['foo_test.py', 'bar_test.py'])
+def test_main_all_pass():
+ ret = main(['foo_test.py', 'bar_test.py'])
assert ret == 0
-def test_validate_files_one_fails():
- ret = validate_files(['not_test_ending.py', 'foo_test.py'])
+def test_main_one_fails():
+ ret = main(['not_test_ending.py', 'foo_test.py'])
assert ret == 1
-def test_validate_files_django_all_pass():
- ret = validate_files(['--django', 'tests.py', 'test_foo.py', 'test_bar.py', 'tests/test_baz.py'])
+def test_main_django_all_pass():
+ ret = main(['--django', 'tests.py', 'test_foo.py', 'test_bar.py', 'tests/test_baz.py'])
assert ret == 0
-def test_validate_files_django_one_fails():
- ret = validate_files(['--django', 'not_test_ending.py', 'test_foo.py'])
+def test_main_django_one_fails():
+ ret = main(['--django', 'not_test_ending.py', 'test_foo.py'])
assert ret == 1
def test_validate_nested_files_django_one_fails():
- ret = validate_files(['--django', 'tests/not_test_ending.py', 'test_foo.py'])
+ ret = main(['--django', 'tests/not_test_ending.py', 'test_foo.py'])
assert ret == 1
-def test_validate_files_not_django_fails():
- ret = validate_files(['foo_test.py', 'bar_test.py', 'test_baz.py'])
+def test_main_not_django_fails():
+ ret = main(['foo_test.py', 'bar_test.py', 'test_baz.py'])
assert ret == 1
-def test_validate_files_django_fails():
- ret = validate_files(['--django', 'foo_test.py', 'test_bar.py', 'test_baz.py'])
+def test_main_django_fails():
+ ret = main(['--django', 'foo_test.py', 'test_bar.py', 'test_baz.py'])
assert ret == 1
diff --git a/tox.ini b/tox.ini
index c131e6f..d1e6a79 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,6 +1,6 @@
[tox]
# These should match the travis env list
-envlist = py27,py36,py37,pypy
+envlist = py27,py36,py37,pypy3
[testenv]
deps = -rrequirements-dev.txt