debug-statements: detect python3.7+ breakpoint()
diff --git a/pre_commit_hooks/debug_statement_hook.py b/pre_commit_hooks/debug_statement_hook.py index c5ca387..81591dd 100644 --- a/pre_commit_hooks/debug_statement_hook.py +++ b/pre_commit_hooks/debug_statement_hook.py
@@ -8,32 +8,33 @@ DEBUG_STATEMENTS = {'pdb', 'ipdb', 'pudb', 'q', 'rdb'} +Debug = collections.namedtuple('Debug', ('line', 'col', 'name', 'reason')) -DebugStatement = collections.namedtuple( - 'DebugStatement', ['name', 'line', 'col'], -) - - -class ImportStatementParser(ast.NodeVisitor): +class DebugStatementParser(ast.NodeVisitor): def __init__(self): - self.debug_import_statements = [] + self.breakpoints = [] def visit_Import(self, node): - for node_name in node.names: - if node_name.name in DEBUG_STATEMENTS: - self.debug_import_statements.append( - DebugStatement(node_name.name, node.lineno, node.col_offset), - ) + for name in node.names: + if name.name in DEBUG_STATEMENTS: + st = Debug(node.lineno, node.col_offset, name.name, 'imported') + self.breakpoints.append(st) def visit_ImportFrom(self, node): if node.module in DEBUG_STATEMENTS: - self.debug_import_statements.append( - DebugStatement(node.module, node.lineno, node.col_offset), - ) + st = Debug(node.lineno, node.col_offset, node.module, 'imported') + self.breakpoints.append(st) + + def visit_Call(self, node): + """python3.7+ breakpoint()""" + if isinstance(node.func, ast.Name) and node.func.id == 'breakpoint': + st = Debug(node.lineno, node.col_offset, node.func.id, 'called') + self.breakpoints.append(st) + self.generic_visit(node) -def check_file_for_debug_statements(filename): +def check_file(filename): try: ast_obj = ast.parse(open(filename, 'rb').read(), filename=filename) except SyntaxError: @@ -42,34 +43,30 @@ print('\t' + traceback.format_exc().replace('\n', '\n\t')) print() return 1 - visitor = ImportStatementParser() + + visitor = DebugStatementParser() visitor.visit(ast_obj) - if visitor.debug_import_statements: - for debug_statement in visitor.debug_import_statements: - print( - '{}:{}:{} - {} imported'.format( - filename, - debug_statement.line, - debug_statement.col, - debug_statement.name, - ), - ) - return 1 - else: - return 0 + + for bp in visitor.breakpoints: + print( + '{}:{}:{} - {} {}'.format( + filename, bp.line, bp.col, bp.name, bp.reason, + ), + ) + + return int(bool(visitor.breakpoints)) -def debug_statement_hook(argv=None): +def main(argv=None): parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*', help='Filenames to run') args = parser.parse_args(argv) retv = 0 for filename in args.filenames: - retv |= check_file_for_debug_statements(filename) - + retv |= check_file(filename) return retv if __name__ == '__main__': - exit(debug_statement_hook()) + exit(main())
diff --git a/setup.py b/setup.py index b9a6a04..a6f7a52 100644 --- a/setup.py +++ b/setup.py
@@ -46,7 +46,7 @@ 'check-vcs-permalinks = pre_commit_hooks.check_vcs_permalinks:main', 'check-xml = pre_commit_hooks.check_xml:check_xml', 'check-yaml = pre_commit_hooks.check_yaml:check_yaml', - 'debug-statement-hook = pre_commit_hooks.debug_statement_hook:debug_statement_hook', + 'debug-statement-hook = pre_commit_hooks.debug_statement_hook:main', 'detect-aws-credentials = pre_commit_hooks.detect_aws_credentials:main', 'detect-private-key = pre_commit_hooks.detect_private_key:detect_private_key', 'double-quote-string-fixer = pre_commit_hooks.string_fixer:main',
diff --git a/tests/debug_statement_hook_test.py b/tests/debug_statement_hook_test.py index 44c2011..d15f5f7 100644 --- a/tests/debug_statement_hook_test.py +++ b/tests/debug_statement_hook_test.py
@@ -4,48 +4,60 @@ import ast -from pre_commit_hooks.debug_statement_hook import debug_statement_hook -from pre_commit_hooks.debug_statement_hook import DebugStatement -from pre_commit_hooks.debug_statement_hook import ImportStatementParser +from pre_commit_hooks.debug_statement_hook import Debug +from pre_commit_hooks.debug_statement_hook import DebugStatementParser +from pre_commit_hooks.debug_statement_hook import main from testing.util import get_resource_path -def test_no_debug_imports(): - visitor = ImportStatementParser() +def test_no_breakpoints(): + visitor = DebugStatementParser() visitor.visit(ast.parse('import os\nfrom foo import bar\n')) - assert visitor.debug_import_statements == [] + assert visitor.breakpoints == [] def test_finds_debug_import_attribute_access(): - visitor = ImportStatementParser() + visitor = DebugStatementParser() visitor.visit(ast.parse('import ipdb; ipdb.set_trace()')) - assert visitor.debug_import_statements == [DebugStatement('ipdb', 1, 0)] + assert visitor.breakpoints == [Debug(1, 0, 'ipdb', 'imported')] def test_finds_debug_import_from_import(): - visitor = ImportStatementParser() + visitor = DebugStatementParser() visitor.visit(ast.parse('from pudb import set_trace; set_trace()')) - assert visitor.debug_import_statements == [DebugStatement('pudb', 1, 0)] + assert visitor.breakpoints == [Debug(1, 0, 'pudb', 'imported')] + + +def test_finds_breakpoint(): + visitor = DebugStatementParser() + visitor.visit(ast.parse('breakpoint()')) + assert visitor.breakpoints == [Debug(1, 0, 'breakpoint', 'called')] def test_returns_one_for_failing_file(tmpdir): f_py = tmpdir.join('f.py') f_py.write('def f():\n import pdb; pdb.set_trace()') - ret = debug_statement_hook([f_py.strpath]) + ret = main([f_py.strpath]) assert ret == 1 def test_returns_zero_for_passing_file(): - ret = debug_statement_hook([__file__]) + ret = main([__file__]) assert ret == 0 def test_syntaxerror_file(): - ret = debug_statement_hook([get_resource_path('cannot_parse_ast.notpy')]) + ret = main([get_resource_path('cannot_parse_ast.notpy')]) assert ret == 1 def test_non_utf8_file(tmpdir): f_py = tmpdir.join('f.py') f_py.write_binary('# -*- coding: cp1252 -*-\nx = "€"\n'.encode('cp1252')) - assert debug_statement_hook((f_py.strpath,)) == 0 + assert main((f_py.strpath,)) == 0 + + +def test_py37_breakpoint(tmpdir): + f_py = tmpdir.join('f.py') + f_py.write('def f():\n breakpoint()\n') + assert main((f_py.strpath,)) == 1