aboutsummaryrefslogtreecommitdiff
blob: 5947e9777c4b487eb305583c3df67013fe173d09 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from rpython.flowspace.model import Constant, Variable, mkentrymap
from rpython.tool.ansi_print import AnsiLogger

log = AnsiLogger("backendopt")


def is_chain_block(block, first=False):
    """Tell whether `block` can be one element of a chain of equality tests.

    The block qualifies when its last operation is one of 'int_eq',
    'uint_eq', 'char_eq' or 'unichar_eq', the result of that operation is
    the block's exitswitch, the two compared arguments are neither both
    Variables nor both Constants, and the value of the constant side can
    be hashed.

    Unless `first` is true, the block must also contain nothing but that
    single comparison.
    """
    ops = block.operations
    if not ops:
        return False
    if not first and len(ops) > 1:
        return False
    op = ops[-1]
    # note: 'llong_eq', 'ullong_eq' are deliberately not accepted, as it's
    # not strictly C-compliant to do a switch() on a long long.  It also
    # crashes the JIT, and it's very very rare anyway.
    if op.opname not in ('int_eq', 'uint_eq', 'char_eq', 'unichar_eq'):
        return False
    if op.result != block.exitswitch:
        return False
    # comparing two variables or two constants cannot become a switch case
    for cls in (Variable, Constant):
        if isinstance(op.args[0], cls) and isinstance(op.args[1], cls):
            return False
    # check that the constant is hashable (ie not a symbolic)
    try:
        if isinstance(op.args[0], Constant):
            constant = op.args[0]
        else:
            constant = op.args[1]
        hash(constant.value)
    except TypeError:
        return False
    return True

def merge_chain(chain, checkvar, varmap, graph):
    """Turn the first block of `chain` into a single multi-exit block.

    `chain` is a sequence of (block, case) pairs, where `case.value` is the
    value tested by that block.  The last operation of the first block is
    dropped and `checkvar` becomes its exitswitch.  For every distinct case
    value, exit 1 of the corresponding block becomes an exit of the first
    block labelled with that value.  Exit 0 of the last block becomes the
    "default" exit.  Link arguments that are not Constants are replaced by
    their entry in `varmap`.

    `graph` is not used by this function.
    """
    def remap(args):
        # Constants are kept as they are, anything else goes through varmap
        renamed = []
        for arg in args:
            if not isinstance(arg, Constant):
                arg = varmap[arg]
            renamed.append(arg)
        return renamed

    firstblock = chain[0][0]
    lastblock = chain[-1][0]
    firstblock.operations = firstblock.operations[:-1]
    firstblock.exitswitch = checkvar

    default = lastblock.exits[0]
    default.exitcase = "default"
    default.llexitcase = None
    default.args = remap(default.args)

    seen = set()
    links = []
    for block, case in chain:
        value = case.value
        if value in seen:
            # a later test on an already handled value is unreachable;
            # ignore it silently: it occurs in platform-dependent
            # chains of tests, for example
            continue
        seen.add(value)
        link = block.exits[1]
        link.exitcase = value
        link.llexitcase = value
        link.args = remap(link.args)
        links.append(link)
    links.append(default)
    firstblock.recloseblock(*links)

def merge_if_blocks_once(graph):
    """Convert consecutive blocks that all compare a variable (of Primitive type)
    with a constant into one block with multiple exits. The backends can in
    turn output this block as a switch statement.

    Only the first chain of more than one block that is found gets merged.
    Returns True if such a chain was merged, False if there was none.
    """
    starts = []
    for block in graph.iterblocks():
        if is_chain_block(block, first=True):
            starts.append(block)
    entrymap = mkentrymap(graph)
    for firstblock in starts:
        chain = []
        checkvars = []
        # {var in a block in the chain: var in the first block}
        varmap = {}
        for link in (firstblock.exits[0], firstblock.exits[1]):
            for var in link.args:
                varmap[var] = var
        current = firstblock
        while True:
            # check whether the chain can be extended with the block that
            # follows the False link
            lastop = current.operations[-1]
            checkvar = [arg for arg in lastop.args
                        if isinstance(arg, Variable)][0]
            resvar = lastop.result
            case = [arg for arg in lastop.args
                    if isinstance(arg, Constant)][0]
            checkvars.append(checkvar)
            falseexit = current.exits[0]
            assert not falseexit.exitcase
            trueexit = current.exits[1]
            targetblock = falseexit.target
            # if the result of the check is also passed through the link, we
            # cannot construct the chain
            if resvar in falseexit.args or resvar in trueexit.args:
                break
            chain.append((current, case))
            # the next block must be reached only from here, and must
            # receive the variable being tested
            if (len(entrymap[targetblock]) != 1
                    or checkvar not in falseexit.args):
                break
            newcheckvar = targetblock.inputargs[falseexit.args.index(checkvar)]
            # the next block must itself be a test on that same variable
            if not (is_chain_block(targetblock)
                    and newcheckvar in targetblock.operations[0].args):
                break
            # record what the input variables of both targets correspond
            # to in the first block
            for link in (trueexit, falseexit):
                for i, var in enumerate(link.args):
                    newvar = link.target.inputargs[i]
                    if isinstance(var, Variable):
                        varmap[newvar] = varmap[var]
                    else:
                        varmap[newvar] = var
            current = targetblock
        if len(chain) > 1:
            merge_chain(chain, checkvars[0], varmap, graph)
            return True
    return False

def merge_if_blocks(graph, verbose=True):
    """Apply merge_if_blocks_once() to `graph` until it reports no change.

    If at least one merge happened, this is logged: as a message naming
    the graph when `verbose` is true, as a single dot otherwise.
    """
    merged_any = False
    while merge_if_blocks_once(graph):
        merged_any = True
    if not merged_any:
        return
    if verbose:
        log("merging blocks in %s" % (graph.name, ))
    else:
        log.dot()