| """Astroid hooks for various builtins.""" |
| |
| import sys |
| from functools import partial |
| from textwrap import dedent |
| |
| import six |
| from astroid import (MANAGER, UseInferenceDefault, |
| inference_tip, YES, InferenceError, UnresolvableName) |
| from astroid import nodes |
| from astroid.builder import AstroidBuilder |
| |
| |
def _extend_str(class_node, rvalue):
    """Extend a builtin string class node with stub methods.

    A throwaway class is built from source.  Its methods return *rvalue*
    for the operations that yield the same string type (``join``,
    ``replace``, ``format``, case changes, stripping and padding), ``0``
    for ``index``/``find``/``count``, a bytes literal from ``encode`` and
    a text literal from ``decode``.  Each of these methods is then stored
    in *class_node*'s locals (replacing any existing entry with that name)
    and re-parented to *class_node*.

    :param class_node: the builtins class node to extend (``str``,
        ``bytes`` or ``unicode`` in this module).
    :param rvalue: source text of the literal the stubs return,
        e.g. ``"''"`` or ``"b''"``.
    """
    # TODO(cpopa): this approach will make astroid to believe
    # that some arguments can be passed by keyword, but
    # unfortunately, strings and bytes don't accept keyword arguments.
    #
    # ``encode`` always produces bytes and ``decode`` always produces text.
    # ``b''`` is ``str`` on Python 2 and ``bytes`` on Python 3, so it is
    # the correct literal on both; a plain ``''`` would make astroid think
    # ``str.encode`` returns text on Python 3.
    code = dedent('''
    class whatever(object):
        def join(self, iterable):
            return {rvalue}
        def replace(self, old, new, count=None):
            return {rvalue}
        def format(self, *args, **kwargs):
            return {rvalue}
        def encode(self, encoding='ascii', errors=None):
            return b''
        def decode(self, encoding='ascii', errors=None):
            return u''
        def capitalize(self):
            return {rvalue}
        def title(self):
            return {rvalue}
        def lower(self):
            return {rvalue}
        def upper(self):
            return {rvalue}
        def swapcase(self):
            return {rvalue}
        def index(self, sub, start=None, end=None):
            return 0
        def find(self, sub, start=None, end=None):
            return 0
        def count(self, sub, start=None, end=None):
            return 0
        def strip(self, chars=None):
            return {rvalue}
        def lstrip(self, chars=None):
            return {rvalue}
        def rstrip(self, chars=None):
            return {rvalue}
        def rjust(self, width, fillchar=None):
            return {rvalue}
        def center(self, width, fillchar=None):
            return {rvalue}
        def ljust(self, width, fillchar=None):
            return {rvalue}
    ''')
    code = code.format(rvalue=rvalue)
    fake = AstroidBuilder(MANAGER).string_build(code)['whatever']
    for method in fake.mymethods():
        # Graft the stub onto the real builtin class.
        class_node.locals[method.name] = [method]
        method.parent = class_node
| |
def extend_builtins(class_transforms):
    """Run a transform on selected classes of the builtins module.

    :param class_transforms: mapping of builtins class name -> callable.
        Each callable is invoked with the class node looked up by that name
        in the cached builtins module; its return value is ignored.
    """
    from astroid.bases import BUILTINS
    builtins_module = MANAGER.astroid_cache[BUILTINS]
    for name in class_transforms:
        apply_transform = class_transforms[name]
        apply_transform(builtins_module[name])
| |
# Attach string-method stubs to the text and binary classes of the running
# interpreter; each ``rvalue`` is the literal the stubs of that class return.
extend_builtins(
    {'bytes': partial(_extend_str, rvalue="b''"),
     'str': partial(_extend_str, rvalue="''")}
    if sys.version_info > (3, 0) else
    {'str': partial(_extend_str, rvalue="''"),
     'unicode': partial(_extend_str, rvalue="u''")}
)
| |
| |
def register_builtin_transform(transform, builtin_name):
    """Register a new transform function for the given *builtin_name*.

    The transform function must accept two parameters, a node and
    an optional context.  It is installed as an inference tip on every
    ``CallFunc`` node whose callee is a plain ``Name`` equal to
    *builtin_name*, and its return value is yielded as the single
    inference result.

    :param transform: callable ``transform(node, context=None)``.
    :param builtin_name: name of the builtin whose calls are handled.
    """
    def _transform_wrapper(node, context=None):
        result = transform(node, context=context)
        if result:
            # A transform may return a node that already belongs to the
            # tree (e.g. the inferred argument of ``tuple(x)`` is returned
            # unchanged).  Only fill in missing location information so
            # such nodes are not re-parented or moved; freshly built nodes
            # still get attached to the call they were inferred from.
            if getattr(result, 'parent', None) is None:
                result.parent = node
            if getattr(result, 'lineno', None) is None:
                result.lineno = node.lineno
            if getattr(result, 'col_offset', None) is None:
                result.col_offset = node.col_offset
        return iter([result])

    MANAGER.register_transform(nodes.CallFunc,
                               inference_tip(_transform_wrapper),
                               lambda n: (isinstance(n.func, nodes.Name) and
                                          n.func.name == builtin_name))
| |
| |
def _generic_inference(node, context, node_type, transform):
    """Infer a call to a single-argument builtin container constructor.

    :param node: the call node being inferred.
    :param context: inference context, forwarded to ``arg.infer``.
    :param node_type: node class instantiated (with no arguments) for a
        call without arguments, e.g. ``tuple()``.
    :param transform: callable turning an argument node into the result
        node; a falsy return means "cannot handle this argument".
    :returns: the inferred node.
    :raises UseInferenceDefault: when the call cannot be handled here.
    """
    # ``*args`` / ``**kwargs`` are stored outside ``node.args``; without
    # this check ``tuple(*x)`` would be mistaken for ``tuple()``.
    if getattr(node, 'starargs', None) or getattr(node, 'kwargs', None):
        raise UseInferenceDefault()

    args = node.args
    if not args:
        return node_type()
    if len(args) > 1:
        raise UseInferenceDefault()

    arg, = args
    # First try the argument as written (e.g. a literal), then fall back
    # to its first inferred value (e.g. a name bound to a literal).
    transformed = transform(arg)
    if not transformed:
        try:
            infered = next(arg.infer(context=context))
        except (InferenceError, StopIteration):
            raise UseInferenceDefault()
        if infered is YES:
            raise UseInferenceDefault()
        transformed = transform(infered)
    if not transformed or transformed is YES:
        raise UseInferenceDefault()
    return transformed
| |
| |
def _generic_transform(arg, klass, iterables, build_elts):
    """Turn *arg* into a *klass* node, or return None if unsupported.

    * an instance of *klass* is returned unchanged;
    * an instance of one of *iterables* contributes its elements' values;
    * a ``Dict`` contributes its keys' values;
    * a ``Const`` holding a str/bytes value contributes that value itself.

    The collected values are passed through *build_elts* and used as the
    ``elts`` of a new *klass* node.

    :raises UseInferenceDefault: if an element (or dict key) is not a
        ``Const`` node.
    """
    if isinstance(arg, klass):
        return arg

    if isinstance(arg, iterables):
        # TODO(cpopa): Don't support heterogeneous elements.
        # Not yet, though.
        if any(not isinstance(element, nodes.Const) for element in arg.elts):
            raise UseInferenceDefault()
        values = [element.value for element in arg.elts]
    elif isinstance(arg, nodes.Dict):
        keys = [item[0] for item in arg.items]
        if any(not isinstance(key, nodes.Const) for key in keys):
            raise UseInferenceDefault()
        values = [key.value for key in keys]
    elif (isinstance(arg, nodes.Const) and
          isinstance(arg.value, (six.string_types, six.binary_type))):
        # The string itself is the iterable; build_elts splits it up.
        values = arg.value
    else:
        return None

    return klass(elts=build_elts(values))
| |
| |
def _infer_builtin(node, context,
                   klass=None, iterables=None,
                   build_elts=None):
    """Infer a call to the builtin container constructor producing *klass*.

    Binds *klass*, *iterables* and *build_elts* into ``_generic_transform``
    and delegates to ``_generic_inference``, using *klass* as the node
    type for an argument-less call.
    """
    return _generic_inference(
        node, context, klass,
        partial(_generic_transform,
                klass=klass,
                iterables=iterables,
                build_elts=build_elts))
| |
# pylint: disable=invalid-name
# Inference functions for tuple(), list() and set().  ``klass`` is the node
# type produced, ``iterables`` the other literal container node types also
# accepted as the argument, and ``build_elts`` the Python constructor applied
# to the argument's constant values.
infer_tuple = partial(_infer_builtin, klass=nodes.Tuple,
                      iterables=(nodes.List, nodes.Set), build_elts=tuple)
infer_list = partial(_infer_builtin, klass=nodes.List,
                     iterables=(nodes.Tuple, nodes.Set), build_elts=list)
infer_set = partial(_infer_builtin, klass=nodes.Set,
                    iterables=(nodes.List, nodes.Tuple), build_elts=set)
| |
| |
def _get_elts(arg, context):
    """Return the (key, value) pairs a ``dict(arg)`` call is built from.

    *arg* is inferred and only the first result is used.  A ``Dict`` yields
    its own items.  A ``List``/``Tuple``/``Set`` must contain only
    two-element ``List``/``Tuple``/``Set`` nodes whose first element is a
    ``Tuple``, ``Const`` or ``Name``; each pair becomes a 2-tuple of nodes.

    :raises UseInferenceDefault: if inference fails, yields nothing, or
        the value does not have one of the shapes above.
    """
    is_iterable = lambda n: isinstance(n,
                                       (nodes.List, nodes.Tuple, nodes.Set))
    try:
        infered = next(arg.infer(context))
    except (InferenceError, UnresolvableName, StopIteration):
        # An empty inference result must fall back to default inference
        # rather than leak StopIteration into the caller's generator.
        raise UseInferenceDefault()
    if isinstance(infered, nodes.Dict):
        items = infered.items
    elif is_iterable(infered):
        items = []
        for elt in infered.elts:
            # If an item is not a pair of two items,
            # then fallback to the default inference.
            # Also, take in consideration only hashable items,
            # tuples and consts. We are choosing Names as well.
            if not is_iterable(elt):
                raise UseInferenceDefault()
            if len(elt.elts) != 2:
                raise UseInferenceDefault()
            if not isinstance(elt.elts[0],
                              (nodes.Tuple, nodes.Const, nodes.Name)):
                raise UseInferenceDefault()
            items.append(tuple(elt.elts))
    else:
        raise UseInferenceDefault()
    return items
| |
def infer_dict(node, context=None):
    """Try to infer a dict call to a Dict node.

    The function treats the following cases:

        * dict()
        * dict(mapping)
        * dict(iterable)
        * dict(iterable, **kwargs)
        * dict(mapping, **kwargs)
        * dict(**kwargs)

    Here ``**kwargs`` means explicit keyword arguments such as
    ``dict(a=1, b=2)``.  Calls that unpack ``*args`` or ``**mapping`` are
    not modelled.  If a case can't be inferred, we fall back to the
    default inference by raising ``UseInferenceDefault``.
    """
    # Star arguments are stored outside ``node.args`` and their contents
    # are unknown here; ignoring them would infer ``dict(*x)`` as an empty
    # dict and ``dict(a=1, **m)`` as a dict missing m's entries.
    if getattr(node, 'starargs', None) or getattr(node, 'kwargs', None):
        raise UseInferenceDefault()

    args = node.args

    def is_keyword(arg):
        return isinstance(arg, nodes.Keyword)

    def keyword_items(keywords):
        # Keyword argument names become constant string keys.
        return [(nodes.Const(kw.arg), kw.value) for kw in keywords]

    if not args:
        # dict()
        return nodes.Dict()
    if all(is_keyword(arg) for arg in args):
        # dict(a=1, b=2, c=4)
        items = keyword_items(args)
    elif len(args) >= 2 and all(is_keyword(arg) for arg in args[1:]):
        # dict(some_iterable, b=2, c=4)
        items = _get_elts(args[0], context) + keyword_items(args[1:])
    elif len(args) == 1:
        # dict(some_iterable) / dict(some_mapping).  Copy the pairs so the
        # new Dict never shares its ``items`` list with the source node.
        items = list(_get_elts(args[0], context))
    else:
        raise UseInferenceDefault()

    result = nodes.Dict()
    result.items = items
    return result
| |
# Builtins inference: route calls to tuple/set/list/dict through the
# matching inference function.
for _builtin_name, _inference_func in (('tuple', infer_tuple),
                                       ('set', infer_set),
                                       ('list', infer_list),
                                       ('dict', infer_dict)):
    register_builtin_transform(_inference_func, _builtin_name)
del _builtin_name, _inference_func