Merge pull request #1202 from pre-commit/fix-builtin-literal-check
fix check-builtin-literals to also detect builtin calls nested inside other calls
diff --git a/pre_commit_hooks/check_builtin_literals.py b/pre_commit_hooks/check_builtin_literals.py
index 16d59b5..e128eea 100644
--- a/pre_commit_hooks/check_builtin_literals.py
+++ b/pre_commit_hooks/check_builtin_literals.py
@@ -26,36 +26,37 @@
class Visitor(ast.NodeVisitor):
def __init__(
self,
- ignore: Sequence[str] | None = None,
+ ignore: set[str],
allow_dict_kwargs: bool = True,
) -> None:
self.builtin_type_calls: list[Call] = []
- self.ignore = set(ignore) if ignore else set()
self.allow_dict_kwargs = allow_dict_kwargs
+ self._disallowed = BUILTIN_TYPES.keys() - ignore
def _check_dict_call(self, node: ast.Call) -> bool:
return self.allow_dict_kwargs and bool(node.keywords)
def visit_Call(self, node: ast.Call) -> None:
- if not isinstance(node.func, ast.Name):
+ if (
# Ignore functions that are object attributes (`foo.bar()`).
# Assume that if the user calls `builtins.list()`, they know what
# they're doing.
- return
- if node.func.id not in set(BUILTIN_TYPES).difference(self.ignore):
- return
- if node.func.id == 'dict' and self._check_dict_call(node):
- return
- elif node.args:
- return
- self.builtin_type_calls.append(
- Call(node.func.id, node.lineno, node.col_offset),
- )
+ isinstance(node.func, ast.Name) and
+ node.func.id in self._disallowed and
+ (node.func.id != 'dict' or not self._check_dict_call(node)) and
+ not node.args
+ ):
+ self.builtin_type_calls.append(
+ Call(node.func.id, node.lineno, node.col_offset),
+ )
+
+ self.generic_visit(node)
def check_file(
filename: str,
- ignore: Sequence[str] | None = None,
+ *,
+ ignore: set[str],
allow_dict_kwargs: bool = True,
) -> list[Call]:
with open(filename, 'rb') as f:
diff --git a/tests/check_builtin_literals_test.py b/tests/check_builtin_literals_test.py
index 1b18257..de29063 100644
--- a/tests/check_builtin_literals_test.py
+++ b/tests/check_builtin_literals_test.py
@@ -38,11 +38,6 @@
'''
-@pytest.fixture
-def visitor():
- return Visitor()
-
-
@pytest.mark.parametrize(
('expression', 'calls'),
[
@@ -85,7 +80,8 @@
('builtins.tuple()', []),
],
)
-def test_non_dict_exprs(visitor, expression, calls):
+def test_non_dict_exprs(expression, calls):
+ visitor = Visitor(ignore=set())
visitor.visit(ast.parse(expression))
assert visitor.builtin_type_calls == calls
@@ -102,7 +98,8 @@
('builtins.dict()', []),
],
)
-def test_dict_allow_kwargs_exprs(visitor, expression, calls):
+def test_dict_allow_kwargs_exprs(expression, calls):
+ visitor = Visitor(ignore=set())
visitor.visit(ast.parse(expression))
assert visitor.builtin_type_calls == calls
@@ -114,17 +111,18 @@
('dict(a=1, b=2, c=3)', [Call('dict', 1, 0)]),
("dict(**{'a': 1, 'b': 2, 'c': 3})", [Call('dict', 1, 0)]),
('builtins.dict()', []),
+ pytest.param('f(dict())', [Call('dict', 1, 2)], id='nested'),
],
)
def test_dict_no_allow_kwargs_exprs(expression, calls):
- visitor = Visitor(allow_dict_kwargs=False)
+ visitor = Visitor(ignore=set(), allow_dict_kwargs=False)
visitor.visit(ast.parse(expression))
assert visitor.builtin_type_calls == calls
def test_ignore_constructors():
visitor = Visitor(
- ignore=('complex', 'dict', 'float', 'int', 'list', 'str', 'tuple'),
+ ignore={'complex', 'dict', 'float', 'int', 'list', 'str', 'tuple'},
)
visitor.visit(ast.parse(BUILTIN_CONSTRUCTORS))
assert visitor.builtin_type_calls == []