summaryrefslogtreecommitdiff
path: root/ujit_core.c
blob: 2cb1c252deb9ba6c277a7e9e3735e53eed52d3ca (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
#include "vm_core.h"
#include "vm_callinfo.h"
#include "builtin.h"
#include "insns.inc"
#include "insns_info.inc"
#include "ujit_asm.h"
#include "ujit_utils.h"
#include "ujit_iface.h"
#include "ujit_core.h"
#include "ujit_codegen.h"

// Maximum number of branch instructions we can track.
// This is the capacity of the fixed-size branch_entries table below;
// gen_branch() asserts before registering an entry past it.
#define MAX_BRANCHES 32768

// Table of block versions indexed by (iseq, index) tuples.
// Keys are pointers to the blockid_t stored inside each version_t and are
// hashed/compared by content via hashtype_blockid (blockid_hash/blockid_cmp);
// values are version_t pointers.
st_table * version_tbl;

// Registered branch entries, indexed by branch_idx.
// num_branches counts the entries handed out so far.
branch_t branch_entries[MAX_BRANCHES];
uint32_t num_branches = 0;

/*
Build a 64-bit memory operand based on REG_SP, displaced by the current
temp-stack depth (8 bytes per slot) plus an extra byte offset
*/
x86opnd_t
ctx_sp_opnd(ctx_t* ctx, int32_t offset_bytes)
{
    int32_t depth_bytes = ctx->stack_size * 8;
    return mem_opnd(64, REG_SP, depth_bytes + offset_bytes);
}

/*
Grow the temp stack by n slots and return an operand addressing
the slot that is now on top
*/
x86opnd_t
ctx_stack_push(ctx_t* ctx, size_t n)
{
    ctx->stack_size += n;

    // The topmost slot lives at REG_SP + (stack_size - 1) * 8
    int32_t top_slot = ctx->stack_size - 1;
    return mem_opnd(64, REG_SP, top_slot * 8);
}

/*
Shrink the temp stack by n slots.
Returns an operand for the slot that was on top before shrinking.
*/
x86opnd_t
ctx_stack_pop(ctx_t* ctx, size_t n)
{
    // Capture the current top before the size changes
    int32_t top_bytes = 8 * (ctx->stack_size - 1);
    x86opnd_t previous_top = mem_opnd(64, REG_SP, top_bytes);

    ctx->stack_size -= n;
    return previous_top;
}

/*
Operand for the temp-stack value idx slots below the top
(idx 0 is the topmost value)
*/
x86opnd_t
ctx_stack_opnd(ctx_t* ctx, int32_t idx)
{
    int32_t slot = ctx->stack_size - 1 - idx;
    return mem_opnd(64, REG_SP, slot * 8);
}

// Look up a compiled version of the block identified by (iseq, idx).
// Returns NULL when no version is registered in version_tbl for that id.
version_t* find_block_version(blockid_t block, const ctx_t* ctx)
{
    // TODO: use ctx to pick among versions; it is currently ignored and the
    // entry stored for this block id (if any) is returned as-is
    st_data_t found = 0;
    if (!rb_st_lookup(version_tbl, (st_data_t)&block, &found)) {
        return NULL;
    }

    return (version_t*)found;
}
// Compile a new version of the block identified by blockid, specialized for
// ctx, at the current write position of cb, and register it in version_tbl.
// Returns the new version; never returns NULL (allocation failure aborts).
version_t* gen_block_version(blockid_t blockid, const ctx_t* ctx)
{
    // Allocate a version object
    version_t* p_version = malloc(sizeof(version_t));
    if (p_version == NULL)
    {
        // Every caller dereferences the returned version unconditionally,
        // so there is no way to report failure; fail loudly instead of
        // dereferencing NULL.
        abort();
    }

    memcpy(&p_version->blockid, &blockid, sizeof(blockid_t));
    memcpy(&p_version->ctx, ctx, sizeof(ctx_t));
    p_version->incoming = NULL;
    p_version->num_incoming = 0;

    // Compile the block version, recording where its code starts and ends
    p_version->start_pos = cb->write_pos;
    ujit_gen_code(p_version);
    p_version->end_pos = cb->write_pos;

    // Keep track of the new block version; the key points into the version
    // itself so it stays valid as long as the version does
    st_insert(version_tbl, (st_data_t)&p_version->blockid, (st_data_t)p_version);

    return p_version;
}

// Generate a block version that is an entry point inserted into an iseq.
// The code is emitted at the current write position of cb by ujit_gen_entry()
// and, on success, the version is registered in version_tbl.
// Returns the entry code pointer, or NULL if no code could be generated or
// the version object could not be allocated.
uint8_t* gen_entry_point(const rb_iseq_t *iseq, uint32_t insn_idx)
{
    // Allocate a version object; bail out before emitting any code if we can't
    version_t* p_version = malloc(sizeof(version_t));
    if (p_version == NULL)
    {
        return NULL;
    }

    blockid_t blockid = { iseq, insn_idx };
    memcpy(&p_version->blockid, &blockid, sizeof(blockid_t));
    p_version->incoming = NULL;
    p_version->num_incoming = 0;

    // The entry context makes no assumptions about types
    ctx_t ctx = { 0 };
    memcpy(&p_version->ctx, &ctx, sizeof(ctx_t));

    // Compile the block version
    p_version->start_pos = cb->write_pos;
    uint8_t* code_ptr = ujit_gen_entry(p_version);
    p_version->end_pos = cb->write_pos;

    // If we couldn't generate any code, the version is not registered
    if (!code_ptr)
    {
        free(p_version);
        return NULL;
    }

    // Keep track of the new block version
    st_insert(version_tbl, (st_data_t)&p_version->blockid, (st_data_t)p_version);

    return code_ptr;
}

// Record branch_idx as an incoming branch of p_version by appending it to
// the heap-allocated p_version->incoming array (num_incoming elements).
static void add_incoming(version_t* p_version, uint32_t branch_idx)
{
    // Grow the list by one element. The size is computed in elements, then
    // scaled to bytes. realloc(NULL, ...) acts like malloc, which covers the
    // first incoming branch (incoming starts out NULL). The old block is
    // released by realloc, so nothing leaks.
    size_t new_count = (size_t)p_version->num_incoming + 1;
    uint32_t* new_list = realloc(p_version->incoming, sizeof(uint32_t) * new_count);
    if (new_list == NULL)
    {
        // Callers cannot observe a failure here; treat OOM as fatal rather
        // than silently dropping the incoming branch.
        abort();
    }

    new_list[p_version->num_incoming] = branch_idx;
    p_version->incoming = new_list;
    p_version->num_incoming += 1;
}

// Called by the generated code when a branch stub is executed.
// Finds (or compiles) the target block of branch branch_idx, patches the
// branch so target target_idx jumps straight to it, and returns the address
// of the target block's code for the stub to jump to.
uint8_t* branch_stub_hit(uint32_t branch_idx, uint32_t target_idx)
{
    assert (branch_idx < num_branches);
    assert (target_idx < 2);
    branch_t *branch = &branch_entries[branch_idx];
    blockid_t target = branch->targets[target_idx];
    ctx_t* target_ctx = &branch->target_ctxs[target_idx];

    // Try to find a compiled version of this block
    version_t* p_version = find_block_version(target, target_ctx);

    // If this block hasn't yet been compiled
    if (!p_version)
    {
        // If the block about to be compiled will be placed immediately after
        // the branch, the branch can fall through to it. This must only be
        // done for a block compiled right now: an existing version lives
        // elsewhere in cb and has to be reached by an explicit jump.
        if (cb->write_pos == branch->end_pos)
        {
            branch->shape = (uint8_t)target_idx;

            // Rewrite the branch with the new, potentially more compact shape
            cb_set_pos(cb, branch->start_pos);
            branch->gen_fn(cb, branch->dst_addrs[0], branch->dst_addrs[1], branch->shape);
            assert (cb->write_pos <= branch->end_pos);

            // The branch may have shrunk; the new block starts right after
            // it, and later end_pos comparisons must see the real end
            branch->end_pos = cb->write_pos;
        }

        p_version = gen_block_version(target, target_ctx);
    }

    // Add this branch to the list of incoming branches for the target
    add_incoming(p_version, branch_idx);

    // Update the branch target address
    uint8_t* dst_addr = cb_get_ptr(cb, p_version->start_pos);
    branch->dst_addrs[target_idx] = dst_addr;

    // Rewrite the branch in place with the new jump target address, then
    // restore the write position so new code keeps being appended at the end
    assert (branch->dst_addrs[0] != NULL);
    assert (branch->dst_addrs[1] != NULL);
    uint32_t cur_pos = cb->write_pos;
    cb_set_pos(cb, branch->start_pos);
    branch->gen_fn(cb, branch->dst_addrs[0], branch->dst_addrs[1], branch->shape);
    assert (cb->write_pos <= branch->end_pos);
    cb_set_pos(cb, cur_pos);

    // Return a pointer to the compiled block version
    return dst_addr;
}

// Get a version or stub corresponding to a branch target.
//
// If a compiled version of `target` already exists, branch_idx is recorded as
// one of its incoming branches and the address of its code in cb is returned.
// Otherwise a stub is emitted into `ocb` and its address is returned. When
// executed, the stub calls branch_stub_hit(branch_idx, target_idx) and jumps
// to the address that call returns.
//
// TODO: need incoming and target contexts
uint8_t* get_branch_target(
    blockid_t target,
    const ctx_t* ctx,
    codeblock_t* ocb,
    uint32_t branch_idx,
    uint32_t target_idx
)
{
    version_t* p_version = find_block_version(target, ctx);

    if (p_version)
    {
        // Add an incoming branch for this version
        add_incoming(p_version, branch_idx);

        return cb_get_ptr(cb, p_version->start_pos);
    }

    // Generate an outlined stub that will call
    // branch_stub_hit(uint32_t branch_idx, uint32_t target_idx)
    // The stub starts at the current end of ocb.
    uint8_t* stub_addr = cb_get_ptr(ocb, ocb->write_pos);

    // Preserve the ujit registers across the C call.
    // REG_SP is pushed (and later popped) twice so that an even number of
    // 8-byte values is pushed; presumably this keeps the native stack 16-byte
    // aligned at the call below, as the SysV ABI requires. TODO confirm.
    push(ocb, REG_CFP);
    push(ocb, REG_EC);
    push(ocb, REG_SP);
    push(ocb, REG_SP);

    // First two integer arguments (RDI, RSI) under the SysV AMD64 convention
    mov(ocb, RDI, imm_opnd(branch_idx));
    mov(ocb, RSI, imm_opnd(target_idx));
    call_ptr(ocb, REG0, (void *)&branch_stub_hit);

    // Restore the ujit registers, in reverse order of the pushes above
    pop(ocb, REG_SP);
    pop(ocb, REG_SP);
    pop(ocb, REG_EC);
    pop(ocb, REG_CFP);

    // Jump to the address returned by the
    // branch_stub_hit call (returned in RAX)
    jmp_rm(ocb, RAX);

    return stub_addr;
}

// Emit a branch from the current block and register it in branch_entries.
//
// With a single target (target1.iseq == NULL), target0 is looked up and, if
// missing, compiled immediately at the current write position. With two
// targets, each destination is either an existing version or a stub that
// calls branch_stub_hit() when first executed.
//
// gen_fn emits the branch instructions. Their start/end positions are
// recorded so branch_stub_hit() can rewrite the branch later.
// src_ctx, ctx0 and ctx1 are all dereferenced below, so they must be
// non-NULL even when target1 is unused.
void gen_branch(
    const ctx_t* src_ctx,
    blockid_t target0,
    const ctx_t* ctx0,
    blockid_t target1,
    const ctx_t* ctx1,
    branchgen_fn gen_fn
)
{
    assert (num_branches < MAX_BRANCHES);

    // Claim this branch's slot before anything below can recurse into the
    // compiler. gen_block_version() compiles code that may call gen_branch()
    // itself, and those nested branches must get their own indices rather
    // than reusing (and overwriting) this one.
    uint32_t branch_idx = num_branches++;

    // Branch targets or stub addresses (code pointers)
    uint8_t* dst_addr0;
    uint8_t* dst_addr1;

    // If there's only one branch target
    if (target1.iseq == NULL)
    {
        version_t* p_version = find_block_version(target0, ctx0);

        // If the version doesn't already exist
        if (!p_version)
        {
            // No need for a jump, compile the target block right here.
            // NOTE(review): the branch emitted by gen_fn below then lands
            // *after* this block's code rather than before it; confirm this
            // layout is intended.
            p_version = gen_block_version(target0, ctx0);
        }

        add_incoming(p_version, branch_idx);
        dst_addr0 = cb_get_ptr(cb, p_version->start_pos);
        dst_addr1 = NULL;
    }
    else
    {
        // Get the branch targets or stubs
        dst_addr0 = get_branch_target(target0, ctx0, ocb, branch_idx, 0);
        dst_addr1 = get_branch_target(target1, ctx1, ocb, branch_idx, 1);
    }

    // Call the branch generation function
    uint32_t start_pos = cb->write_pos;
    gen_fn(cb, dst_addr0, dst_addr1, SHAPE_DEFAULT);
    uint32_t end_pos = cb->write_pos;

    // Register this branch entry in the slot reserved above
    branch_t branch_entry = {
        start_pos,
        end_pos,
        *src_ctx,
        { target0, target1 },
        { *ctx0, *ctx1 },
        { dst_addr0, dst_addr1 },
        gen_fn,
        SHAPE_DEFAULT
    };

    branch_entries[branch_idx] = branch_entry;
}

// Invalidate one specific block version.
//
// Not implemented yet: the body is intentionally a no-op. Planned steps:
//  - Atomically patch every branch jumping to the block so it jumps to a
//    stub instead.
//  - Handle blocks that fall through into the invalidated block because they
//    immediately precede it:
//      * if such a fall-through branch is too short to be patched, its own
//        block may need to be invalidated as well;
//      * this may not matter in practice, since the target block could have
//        enough space;
//      * alternatively, any block that may need invalidation could be forced
//        to reserve room for a jump to a stub.
//  - If the block is an entry point, unmap it from its iseq (unmap/remap
//    anything at this iseq/idx).
//  - Optionally, compile a new deoptimized entry point.
//  - Call continuation addresses on the stack can also be atomically
//    replaced by jumps going to the stub.
void invalidate(version_t* version)
{
    (void)version; // unused until invalidation is implemented
}

// Key comparison callback for version_tbl.
// Ruby's st table treats a compare result of 0 as "keys are equal" (like
// memcmp), so this returns 0 when both blockid_t keys have the same iseq and
// idx, and non-zero otherwise.
int blockid_cmp(st_data_t arg0, st_data_t arg1)
{
    const blockid_t *block0 = (const blockid_t*)arg0;
    const blockid_t *block1 = (const blockid_t*)arg1;
    return block0->iseq != block1->iseq || block0->idx != block1->idx;
}

// Hash callback for version_tbl: combines the numeric hashes of the key's
// iseq pointer and instruction index with XOR.
st_index_t blockid_hash(st_data_t arg)
{
    const blockid_t *key = (const blockid_t*)arg;
    return st_numhash((st_data_t)key->iseq)
         ^ st_numhash((st_data_t)(uint64_t)key->idx);
}

static const struct st_hash_type hashtype_blockid = {
    blockid_cmp,
    blockid_hash,
};

// Set up the block-versioning core by creating the empty table that maps
// block ids to compiled versions
void
ujit_init_core(void)
{
    const struct st_hash_type *key_type = &hashtype_blockid;
    version_tbl = st_init_table(key_type);
}