aboutsummaryrefslogtreecommitdiff
blob: 2696abda74321f0ccbbf006ef0cea8a1f98e9132 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
"""
Find intermediate evaluation results in assert statements through the builtin AST.
This should replace oldinterpret.py eventually.
"""

import sys
import ast

import py
from _pytest.assertion import util
from _pytest.assertion.reinterpret import BuiltinAssertionError


if sys.platform.startswith("java") and sys.version_info < (2, 5, 2):
    # See http://bugs.jython.org/issue1497
    _exprs = ("BoolOp", "BinOp", "UnaryOp", "Lambda", "IfExp", "Dict",
              "ListComp", "GeneratorExp", "Yield", "Compare", "Call",
              "Repr", "Num", "Str", "Attribute", "Subscript", "Name",
              "List", "Tuple")
    _stmts = ("FunctionDef", "ClassDef", "Return", "Delete", "Assign",
              "AugAssign", "Print", "For", "While", "If", "With", "Raise",
              "TryExcept", "TryFinally", "Assert", "Import", "ImportFrom",
              "Exec", "Global", "Expr", "Pass", "Break", "Continue")
    _expr_nodes = set(getattr(ast, name) for name in _exprs)
    _stmt_nodes = set(getattr(ast, name) for name in _stmts)
    def _is_ast_expr(node):
        return node.__class__ in _expr_nodes
    def _is_ast_stmt(node):
        return node.__class__ in _stmt_nodes
else:
    def _is_ast_expr(node):
        return isinstance(node, ast.expr)
    def _is_ast_stmt(node):
        return isinstance(node, ast.stmt)


class Failure(Exception):
    """Raised when evaluating part of an assertion's AST fails.

    Construct it inside an ``except`` block: the exception being handled at
    that moment is captured as ``cause`` (the ``sys.exc_info()`` triple),
    which ``getfailure`` later reads to name and describe the error.
    """

    def __init__(self, explanation=""):
        # explanation: text describing the (sub-)expression being evaluated
        #   when the error occurred.
        # cause: (type, value, traceback) of the exception being handled.
        self.explanation, self.cause = explanation, sys.exc_info()


def interpret(source, frame, should_fail=False):
    """Re-run the statement(s) in *source* inside *frame* to explain a failure.

    Returns the formatted text produced by ``getfailure`` when interpretation
    raises ``Failure``.  Otherwise, returns a hint string if *should_fail* is
    true (the assert failed originally but passed when re-run), else None.
    """
    tree = ast.parse(source)
    try:
        DebugInterpreter(frame).visit(tree)
    except Failure:
        return getfailure(sys.exc_info()[1])
    if not should_fail:
        return None
    return ("(assertion failed, but when it was re-run for "
            "printing intermediate values, it did not fail.  Suggestions: "
            "compute assert expression before the assert or use --no-assert)")

def run(offending_line, frame=None):
    """Interpret *offending_line* in *frame* (default: the caller's frame).

    Returns whatever ``interpret`` returns for that source and frame.
    """
    if frame is not None:
        target = frame
    else:
        # Depth 1 is the frame that called run(); this lookup has to stay
        # directly inside run() for that to hold.
        target = py.code.Frame(sys._getframe(1))
    return interpret(offending_line, target)

def getfailure(e):
    """Render the Failure *e* as ``"<exception name>: <explanation>"``.

    The explanation is formatted with ``util.format_explanation``.  When the
    captured exception's str() is non-empty, ``  << <message>`` is appended
    to the first line.  A leading ``AssertionError: `` is dropped when the
    text is a plain ``assert ...`` explanation.
    """
    explanation = util.format_explanation(e.explanation)
    exc_type, exc_value = e.cause[0], e.cause[1]
    if str(exc_value):
        # Annotate only the first line; any following lines stay as they are.
        head, newline, tail = explanation.partition('\n')
        head += "  << %s" % (exc_value,)
        explanation = head + newline + tail
    text = "%s: %s" % (exc_type.__name__, explanation)
    if not text.startswith('AssertionError: assert '):
        return text
    # 16 == len("AssertionError: "); keep just the "assert ..." part.
    return text[16:]

# Source-text symbol for each binary-operator and comparison node type.  The
# symbol is used both to rebuild the operation for evaluation and in the
# rendered explanation.
operator_map = {
    ast.BitOr : "|",
    ast.BitXor : "^",
    ast.BitAnd : "&",
    ast.LShift : "<<",
    ast.RShift : ">>",
    ast.Add : "+",
    ast.Sub : "-",
    ast.Mult : "*",
    ast.Div : "/",
    ast.FloorDiv : "//",
    ast.Mod : "%",
    ast.Eq : "==",
    ast.NotEq : "!=",
    ast.Lt : "<",
    ast.LtE : "<=",
    ast.Gt : ">",
    ast.GtE : ">=",
    ast.Pow : "**",
    ast.Is : "is",
    ast.IsNot : "is not",
    ast.In : "in",
    ast.NotIn : "not in"
}
# Matrix multiplication (``@``, PEP 465) only exists on Python 3.5+.  Without
# this entry, an ``a @ b`` inside an assert hits a KeyError on the
# ``operator_map`` lookup in visit_BinOp.
if hasattr(ast, "MatMult"):
    operator_map[ast.MatMult] = "@"

# Format pattern for each unary operator node type.  The "%s" is filled with
# the operand's explanation for display, or with a placeholder variable name
# when the operation is rebuilt for evaluation.
unary_map = dict([
    (ast.Not, "not %s"),
    (ast.Invert, "~%s"),
    (ast.USub, "-%s"),
    (ast.UAdd, "+%s"),
])


class DebugInterpreter(ast.NodeVisitor):
    """Interpret AST nodes to glean useful debugging information.

    Each sub-expression of an assert statement is re-evaluated in
    ``self.frame``.  Expression visitors return an ``(explanation, result)``
    pair: ``explanation`` is display text for the sub-expression (nested
    ``{value = ...}`` blocks record intermediate values) and ``result`` is
    the value it evaluated to.  When an evaluation raises, ``Failure`` is
    raised from inside the ``except`` block so that it captures that
    exception as its cause.
    """

    def __init__(self, frame):
        # Frame all code is evaluated in (a ``py.code.Frame`` when created via
        # ``run``); it must provide eval, exec_, repr and is_true.
        self.frame = frame

    def generic_visit(self, node):
        """Fallback for nodes without a dedicated ``visit_*`` method.

        Expressions are compiled, evaluated in the frame and explained by
        the repr of their value.  Statements are executed in the frame and
        yield ``(None, None)``.

        Raises:
            Failure: if evaluating or executing the node raises.
            AssertionError: if the node is neither an expression nor a
                statement.
        """
        if _is_ast_expr(node):
            mod = ast.Expression(node)
            co = self._compile(mod)
            try:
                result = self.frame.eval(co)
            except Exception:
                raise Failure()
            explanation = self.frame.repr(result)
            return explanation, result
        elif _is_ast_stmt(node):
            mod = self._make_module([node])
            co = self._compile(mod, "exec")
            try:
                self.frame.exec_(co)
            except Exception:
                raise Failure()
            return None, None
        else:
            raise AssertionError("can't handle %s" %(node,))

    def _compile(self, source, mode="eval"):
        """Compile *source* (text or an AST root) under a fixed filename."""
        return compile(source, "<assertion interpretation>", mode)

    def _make_module(self, body):
        """Wrap a list of statement nodes in an ``ast.Module`` for compile().

        Since Python 3.8 ``ast.Module`` has a ``type_ignores`` field that
        compile() rejects when missing (3.8-3.12), so it is filled in
        whenever the running interpreter defines it.
        """
        mod = ast.Module(body)
        if "type_ignores" in getattr(ast.Module, "_fields", ()):
            mod.type_ignores = []
        return mod

    def visit_Expr(self, expr):
        """Explain an expression statement by its value."""
        return self.visit(expr.value)

    def visit_Module(self, mod):
        """Visit each top-level statement in order; returns None."""
        for stmt in mod.body:
            self.visit(stmt)

    def visit_Name(self, name):
        """Evaluate a name; show its repr only if it is a local variable."""
        explanation, result = self.generic_visit(name)
        # Chained comparison: true when the name is in locals() and locals()
        # is not globals().
        source = "%r in locals() is not globals()" % (name.id,)
        co = self._compile(source)
        try:
            local = self.frame.eval(co)
        except Exception:
            # have to assume it isn't
            local = None
        if local is None or not self.frame.is_true(local):
            # Non-local names are displayed by name alone.
            return name.id, result
        return explanation, result

    def visit_Compare(self, comp):
        """Evaluate a (possibly chained) comparison link by link.

        Stops at the first link that is false (or whose truth value cannot
        be determined) and explains that link.  If ``util._reprcompare`` is
        set, it may replace the explanation with a richer one for the
        operands of the link that was explained.

        Raises:
            Failure: if evaluating a comparison raises.
        """
        left_explanation, left_result = self.visit(comp.left)
        for op, next_op in zip(comp.ops, comp.comparators):
            next_explanation, next_result = self.visit(next_op)
            op_symbol = operator_map[op.__class__]
            explanation = "%s %s %s" % (left_explanation, op_symbol,
                                        next_explanation)
            # Remember this link's left operand: left_result is advanced to
            # next_result below when the link holds, so after a loop that
            # runs to completion it no longer names the compared value.
            compared_left = left_result
            source = "__exprinfo_left %s __exprinfo_right" % (op_symbol,)
            co = self._compile(source)
            try:
                result = self.frame.eval(co, __exprinfo_left=left_result,
                                         __exprinfo_right=next_result)
            except Exception:
                raise Failure(explanation)
            try:
                if not self.frame.is_true(result):
                    break
            except KeyboardInterrupt:
                raise
            except:
                # Truthiness itself failed; explain this link as-is.
                break
            left_explanation, left_result = next_explanation, next_result

        if util._reprcompare is not None:
            res = util._reprcompare(op_symbol, compared_left, next_result)
            if res:
                explanation = res
        return explanation, result

    def visit_BoolOp(self, boolop):
        """Evaluate ``and`` / ``or`` with Python's short-circuit semantics.

        Only the operands Python itself would evaluate are visited; the
        explanation lists those operands and the result is the last value
        evaluated.

        Raises:
            Failure: if determining an operand's truth value raises.
        """
        is_or = isinstance(boolop.op, ast.Or)
        joiner = " or " if is_or else " and "
        explanations = []
        for operand in boolop.values:
            explanation, result = self.visit(operand)
            explanations.append(explanation)
            # Short-circuit on truthiness: ``or`` stops at the first true
            # operand, ``and`` at the first false one.  (Comparing the value
            # itself with True/False missed e.g. None, [] or "".)
            try:
                operand_is_true = bool(self.frame.is_true(result))
            except Exception:
                raise Failure("(" + joiner.join(explanations) + ")")
            if operand_is_true == is_or:
                break
        explanation = "(" + joiner.join(explanations) + ")"
        return explanation, result

    def visit_UnaryOp(self, unary):
        """Apply a unary operator to the already-evaluated operand.

        Raises:
            Failure: if applying the operator raises.
        """
        pattern = unary_map[unary.op.__class__]
        operand_explanation, operand_result = self.visit(unary.operand)
        explanation = pattern % (operand_explanation,)
        co = self._compile(pattern % ("__exprinfo_expr",))
        try:
            result = self.frame.eval(co, __exprinfo_expr=operand_result)
        except Exception:
            raise Failure(explanation)
        return explanation, result

    def visit_BinOp(self, binop):
        """Evaluate ``left <op> right`` from the already-evaluated operands.

        Raises:
            Failure: if applying the operator raises.
        """
        left_explanation, left_result = self.visit(binop.left)
        right_explanation, right_result = self.visit(binop.right)
        symbol = operator_map[binop.op.__class__]
        explanation = "(%s %s %s)" % (left_explanation, symbol,
                                      right_explanation)
        source = "__exprinfo_left %s __exprinfo_right" % (symbol,)
        co = self._compile(source)
        try:
            result = self.frame.eval(co, __exprinfo_left=left_result,
                                     __exprinfo_right=right_result)
        except Exception:
            raise Failure(explanation)
        return explanation, result

    def visit_Call(self, call):
        """Evaluate a call, explaining the callee and every argument.

        Each evaluated argument is bound to a unique ``__exprinfo_*`` name
        and the call is re-issued as ``__exprinfo_func(...)`` in the frame.
        Both AST layouts are handled: Python < 3.5 (separate ``starargs`` /
        ``kwargs`` fields) and 3.5+ (``ast.Starred`` inside ``args``, and
        ``**mapping`` as a keyword whose ``arg`` is None).

        Raises:
            Failure: if the call raises.
        """
        func_explanation, func = self.visit(call.func)
        arg_explanations = []
        ns = {"__exprinfo_func" : func}
        arguments = []
        # ast.Starred does not exist on Python 2; isinstance(x, ()) is False.
        starred_type = getattr(ast, "Starred", ())
        for arg in call.args:
            prefix = ""
            if isinstance(arg, starred_type):
                prefix = "*"
                arg = arg.value
            arg_explanation, arg_result = self.visit(arg)
            arg_name = "__exprinfo_%s" % (len(ns),)
            ns[arg_name] = arg_result
            arguments.append(prefix + arg_name)
            arg_explanations.append("%s%s" % (prefix, arg_explanation))
        for keyword in call.keywords:
            arg_explanation, arg_result = self.visit(keyword.value)
            arg_name = "__exprinfo_%s" % (len(ns),)
            ns[arg_name] = arg_result
            if keyword.arg is None:
                keyword_source = "**%s"
            else:
                keyword_source = "%s=%%s" % (keyword.arg,)
            arguments.append(keyword_source % (arg_name,))
            arg_explanations.append(keyword_source % (arg_explanation,))
        # ast.Call lost these two fields in Python 3.5.
        starargs = getattr(call, "starargs", None)
        if starargs:
            arg_explanation, arg_result = self.visit(starargs)
            arg_name = "__exprinfo_star"
            ns[arg_name] = arg_result
            arguments.append("*%s" % (arg_name,))
            arg_explanations.append("*%s" % (arg_explanation,))
        kwargs = getattr(call, "kwargs", None)
        if kwargs:
            arg_explanation, arg_result = self.visit(kwargs)
            arg_name = "__exprinfo_kwds"
            ns[arg_name] = arg_result
            arguments.append("**%s" % (arg_name,))
            arg_explanations.append("**%s" % (arg_explanation,))
        args_explained = ", ".join(arg_explanations)
        explanation = "%s(%s)" % (func_explanation, args_explained)
        args = ", ".join(arguments)
        source = "__exprinfo_func(%s)" % (args,)
        co = self._compile(source)
        try:
            result = self.frame.eval(co, **ns)
        except Exception:
            raise Failure(explanation)
        pattern = "%s\n{%s = %s\n}"
        rep = self.frame.repr(result)
        explanation = pattern % (rep, rep, explanation)
        return explanation, result

    def _is_builtin_name(self, name):
        """Return whether *name* is in neither globals() nor locals().

        Such a name is presumably a builtin (or undefined).  Returns False
        if the check itself raises.
        """
        pattern = "%r not in globals() and %r not in locals()"
        source = pattern % (name.id, name.id)
        co = self._compile(source)
        try:
            return self.frame.eval(co)
        except Exception:
            return False

    def visit_Attribute(self, attr):
        """Evaluate ``value.attr`` and record the looked-up value.

        Non-load contexts are delegated to generic_visit.  When the attribute
        is found in the object's ``__dict__`` (or that check raises), the
        explanation is wrapped once more with the value.

        Raises:
            Failure: if the attribute lookup raises.
        """
        if not isinstance(attr.ctx, ast.Load):
            return self.generic_visit(attr)
        source_explanation, source_result = self.visit(attr.value)
        explanation = "%s.%s" % (source_explanation, attr.attr)
        source = "__exprinfo_expr.%s" % (attr.attr,)
        co = self._compile(source)
        try:
            result = self.frame.eval(co, __exprinfo_expr=source_result)
        except Exception:
            raise Failure(explanation)
        explanation = "%s\n{%s = %s.%s\n}" % (self.frame.repr(result),
                                              self.frame.repr(result),
                                              source_explanation, attr.attr)
        # Check if the attr is from an instance.
        source = "%r in getattr(__exprinfo_expr, '__dict__', {})"
        source = source % (attr.attr,)
        co = self._compile(source)
        try:
            from_instance = self.frame.eval(co, __exprinfo_expr=source_result)
        except Exception:
            from_instance = None
        if from_instance is None or self.frame.is_true(from_instance):
            rep = self.frame.repr(result)
            pattern = "%s\n{%s = %s\n}"
            explanation = pattern % (rep, rep, explanation)
        return explanation, result

    def visit_Assert(self, assrt):
        """Evaluate the assert's test and raise Failure if it is false.

        Raises:
            Failure: with a BuiltinAssertionError as its cause when the test
                is false.
        """
        test_explanation, test_result = self.visit(assrt.test)
        explanation = "assert %s" % (test_explanation,)
        if not self.frame.is_true(test_result):
            # Raise and catch only so Failure captures the assertion error
            # as its cause via sys.exc_info().
            try:
                raise BuiltinAssertionError
            except Exception:
                raise Failure(explanation)
        return explanation, test_result

    def visit_Assign(self, assign):
        """Execute an assignment, explaining its right-hand side.

        The assignment is re-executed with its value replaced by the result
        already computed, so the right-hand side is evaluated only once.

        Raises:
            Failure: if executing the assignment raises.
        """
        value_explanation, value_result = self.visit(assign.value)
        explanation = "... = %s" % (value_explanation,)
        name = ast.Name("__exprinfo_expr", ast.Load(),
                        lineno=assign.value.lineno,
                        col_offset=assign.value.col_offset)
        new_assign = ast.Assign(assign.targets, name, lineno=assign.lineno,
                                col_offset=assign.col_offset)
        mod = self._make_module([new_assign])
        co = self._compile(mod, "exec")
        try:
            self.frame.exec_(co, __exprinfo_expr=value_result)
        except Exception:
            raise Failure(explanation)
        return explanation, value_result