/*
 * path: root/ujit_core.c
 * blob: 4a9ddc334daf23da5ab889e748bc2f4cf3e2bc18 (plain)
 */
#include "internal.h"
#include "ujit_asm.h"
#include "ujit_iface.h"
#include "ujit_core.h"
#include "ujit_codegen.h"

// Maximum number of branch instructions we can track
// (this is the capacity of the branch_entries array below)
#define MAX_BRANCHES 32768

// Table of block versions indexed by (iseq, index) tuples
// Keys are pointers to blockid_t values (hashed and compared through
// hashtype_blockid), values are pointers to the generated code
st_table * version_tbl;

// Registered branch entries
// num_branches counts the slots in use; it is also the index that
// gen_branch() gives to the next branch it registers
branch_t branch_entries[MAX_BRANCHES];
uint32_t num_branches = 0;

// Look up the opcode found at the context's current pc
// inside the context's instruction sequence
int
ctx_get_opcode(ctx_t *ctx)
{
    int opcode = opcode_at_pc(ctx->iseq, ctx->pc);

    return opcode;
}

// Fetch the argument number arg_idx (zero-based) of the instruction
// the context is currently positioned on
VALUE
ctx_get_arg(ctx_t* ctx, size_t arg_idx)
{
    // Argument arg_idx is read from the slot at pc[arg_idx + 1]
    size_t slot = arg_idx + 1;

    // The slot must lie within the current instruction
    assert (slot < insn_len(ctx_get_opcode(ctx)));

    return ctx->pc[slot];
}

/*
Build a 64-bit memory operand relative to REG_SP.
The displacement is the tracked stack size (8 bytes per slot)
plus the caller-supplied offset_bytes.
*/
x86opnd_t
ctx_sp_opnd(ctx_t* ctx, int32_t offset_bytes)
{
    int32_t disp = (ctx->stack_size) * 8 + offset_bytes;

    x86opnd_t sp_opnd = mem_opnd(64, REG_SP, disp);
    return sp_opnd;
}

/*
Grow the tracked stack by n slots.
Returns a 64-bit memory operand addressing the new topmost slot.
*/
x86opnd_t
ctx_stack_push(ctx_t* ctx, size_t n)
{
    // Account for the newly reserved slots first
    ctx->stack_size += n;

    // The topmost value lives one slot (8 bytes) below the stack size mark
    int32_t top_disp = (ctx->stack_size - 1) * 8;

    x86opnd_t new_top = mem_opnd(64, REG_SP, top_disp);
    return new_top;
}

/*
Shrink the tracked stack by n slots.
Returns a 64-bit memory operand addressing the slot that was
topmost before the pop took place.
*/
x86opnd_t
ctx_stack_pop(ctx_t* ctx, size_t n)
{
    // Capture the old top before the size changes;
    // the topmost value lives one slot (8 bytes) below the stack size mark
    int32_t old_top_disp = (ctx->stack_size - 1) * 8;
    x86opnd_t old_top = mem_opnd(64, REG_SP, old_top_disp);

    // Now drop the n slots from the tracked size
    ctx->stack_size -= n;

    return old_top;
}

/*
Build a 64-bit memory operand for a stack slot counted from the top:
idx 0 is the topmost value, idx 1 the one below it, and so on.
*/
x86opnd_t
ctx_stack_opnd(ctx_t* ctx, int32_t idx)
{
    // The topmost value lives one slot (8 bytes) below the stack size mark
    int32_t slot_disp = (ctx->stack_size - 1 - idx) * 8;

    return mem_opnd(64, REG_SP, slot_disp);
}

// Key comparison callback for the block version table.
// Both arguments are pointers to blockid_t values.
// Like memcmp/strcmp, an st_hash_type compare callback has to return
// 0 when the keys are EQUAL and nonzero when they differ; returning the
// result of an equality test would invert every lookup.
int blockid_cmp(st_data_t arg0, st_data_t arg1)
{
    const blockid_t *block0 = (const blockid_t*)arg0;
    const blockid_t *block1 = (const blockid_t*)arg1;
    return (block0->iseq != block1->iseq) || (block0->idx != block1->idx);
}

// Hash callback for the block version table.
// arg is a pointer to a blockid_t; the iseq and idx members are hashed
// separately with st_numhash and the two results are XORed together.
st_index_t blockid_hash(st_data_t arg)
{
    const blockid_t *key = (const blockid_t*)arg;

    st_index_t iseq_hash = st_numhash((st_data_t)key->iseq);
    st_index_t idx_hash = st_numhash((st_data_t)(uint64_t)key->idx);

    return iseq_hash ^ idx_hash;
}

// Hash type descriptor for tables keyed by blockid_t pointers:
// pairs the comparison callback with the hash callback defined above.
// Passed to st_init_table() when version_tbl is created.
static const struct st_hash_type hashtype_blockid = {
    blockid_cmp,
    blockid_hash,
};

// Look up the compiled version registered for an (iseq, idx) tuple.
// Returns the pointer stored in version_tbl for this block id,
// or NULL when the table has no entry for it.
uint8_t* find_block_version(blockid_t block)
{
    st_data_t found;
    int hit = rb_st_lookup(version_tbl, (st_data_t)&block, &found);

    return hit ? (uint8_t*)found : NULL;
}

// Called by the generated code when a branch stub is executed
// Triggers compilation of branches and code patching
uint8_t* branch_stub_hit(uint32_t branch_idx, uint32_t target_idx)
{
    assert (branch_idx < num_branches);
    assert (target_idx < 2);
    branch_t branch = branch_entries[branch_idx];
    blockid_t target = branch.targets[target_idx];

    // If either of the target blocks will be placed next
    if (cb->write_pos == branch.end_pos)
    {
        branch.shape = (uint8_t)target_idx;

        // Rewrite the branch with the new, potentially more compact shape
        cb_set_pos(cb, branch.start_pos);
        branch.gen_fn(cb, branch.dst_addrs[0], branch.dst_addrs[1], branch.shape);
        assert (cb->write_pos <= branch.end_pos);
    }

    // Try to find a compiled version of this block
    uint8_t* code_ptr = find_block_version(target);

    // If this block hasn't yet been compiled
    if (!code_ptr)
    {
        code_ptr = ujit_compile_block(target.iseq, target.idx, false);
        st_insert(version_tbl, (st_data_t)&target, (st_data_t)code_ptr);
        branch.dst_addrs[target_idx] = code_ptr;
    }

    // Rewrite the branch with the new jump target address
    size_t cur_pos = cb->write_pos;
    cb_set_pos(cb, branch.start_pos);
    branch.gen_fn(cb, branch.dst_addrs[0], branch.dst_addrs[1], branch.shape);
    assert (cb->write_pos <= branch.end_pos);
    cb_set_pos(cb, cur_pos);

    // Return a pointer to the compiled block version
    return code_ptr;
}

// Resolve a branch target to a code address.
// When a compiled version of the target block exists, its address is
// returned. Otherwise a stub is emitted into ocb and the stub's address
// is returned instead.
// TODO: need incoming and target versioning contexts
uint8_t* get_branch_target(codeblock_t* ocb, blockid_t target, uint32_t branch_idx, uint32_t target_idx)
{
    uint8_t* existing = find_block_version(target);

    if (existing) {
        return existing;
    }

    // The stub begins at the current write position of ocb
    uint8_t* stub_start = cb_get_ptr(ocb, ocb->write_pos);

    // Outlined stub: load the two arguments into RDI and RSI, then call
    // branch_stub_hit(uint32_t branch_idx, uint32_t target_idx)
    mov(ocb, RDI, imm_opnd(branch_idx));
    mov(ocb, RSI, imm_opnd(target_idx));
    call_ptr(ocb, REG0, (void *)&branch_stub_hit);

    // Continue execution at the address that the
    // branch_stub_hit call handed back in RAX
    jmp_rm(ocb, RAX);

    return stub_start;
}

// Emit a two-way branch into cb and register it in branch_entries.
//
// cb:      code block receiving the branch instruction(s)
// ocb:     code block receiving stubs for targets that have no
//          compiled version yet
// target0: block id of the first branch target
// target1: block id of the second branch target
// gen_fn:  callback that writes the actual branch code
void gen_branch(codeblock_t* cb, codeblock_t* ocb, blockid_t target0, blockid_t target1, branchgen_fn gen_fn)
{
    // Make sure there is a free slot BEFORE emitting anything: the stubs
    // generated below embed num_branches as the index of this branch
    assert (num_branches < MAX_BRANCHES);

    // Get branch targets or stubs (code pointers)
    uint8_t* dst_addr0 = get_branch_target(ocb, target0, num_branches, 0);
    uint8_t* dst_addr1 = get_branch_target(ocb, target1, num_branches, 1);

    uint32_t start_pos = (uint32_t)cb->write_pos;

    // Call the branch generation function
    gen_fn(cb, dst_addr0, dst_addr1, SHAPE_DEFAULT);

    uint32_t end_pos = (uint32_t)cb->write_pos;

    // Register this branch entry
    // The recorded positions let branch_stub_hit() rewrite the branch in place
    branch_t branch_entry = {
        start_pos,
        end_pos,
        { target0, target1 },
        { dst_addr0, dst_addr1 },
        gen_fn,
        SHAPE_DEFAULT
    };

    branch_entries[num_branches] = branch_entry;
    num_branches++;
}

void
ujit_init_core(void)
{
    // Initialize the version hash table
    version_tbl = st_init_table(&hashtype_blockid);
}