summaryrefslogtreecommitdiff
path: root/ujit_core.c
blob: 77c329970bb8f091399457c5e2f3d10b5d26256b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
#include "vm_core.h"
#include "vm_callinfo.h"
#include "builtin.h"
#include "insns.inc"
#include "insns_info.inc"
#include "ujit_asm.h"
#include "ujit_utils.h"
#include "ujit_iface.h"
#include "ujit_core.h"
#include "ujit_codegen.h"

// Maximum number of branch instructions we can track
// (capacity of the branch_entries array below)
#define MAX_BRANCHES 32768

// Table of block versions indexed by (iseq, index) tuples
// Keys are pointers to blockid_t, values are pointers to block_t
st_table * version_tbl;

// Registered branch entries
// A branch is identified by its index into this array;
// num_branches is the number of entries registered so far
branch_t branch_entries[MAX_BRANCHES];
uint32_t num_branches = 0;

/*
Build a 64-bit memory operand relative to the adjusted stack pointer:
REG_SP plus the context's stack size (8 bytes per slot) plus offset_bytes
*/
x86opnd_t
ctx_sp_opnd(ctx_t* ctx, int32_t offset_bytes)
{
    // Displacement past every slot currently tracked by the context
    const int32_t disp = offset_bytes + (ctx->stack_size) * 8;

    return mem_opnd(64, REG_SP, disp);
}

/*
Reserve N more value slots on the stack tracked by the context
Return a memory operand addressing the new topmost slot
*/
x86opnd_t
ctx_stack_push(ctx_t* ctx, size_t n)
{
    ctx->stack_size += n;

    // The topmost value sits one 8-byte slot below the adjusted SP
    const int32_t top_disp = (ctx->stack_size - 1) * 8;

    return mem_opnd(64, REG_SP, top_disp);
}

/*
Drop N values from the stack tracked by the context
Return a memory operand addressing the slot that was on top
before the values were dropped
*/
x86opnd_t
ctx_stack_pop(ctx_t* ctx, size_t n)
{
    // Capture the operand for the current top slot before shrinking;
    // the topmost value sits one 8-byte slot below the adjusted SP
    const int32_t top_disp = (ctx->stack_size - 1) * 8;
    const x86opnd_t old_top = mem_opnd(64, REG_SP, top_disp);

    ctx->stack_size -= n;
    return old_top;
}

/*
Get a memory operand for the stack slot located idx entries below the top
(idx 0 addresses the topmost value)
*/
x86opnd_t
ctx_stack_opnd(ctx_t* ctx, int32_t idx)
{
    // The topmost value sits one 8-byte slot below the adjusted SP
    const int32_t slot_disp = (ctx->stack_size - 1 - idx) * 8;

    return mem_opnd(64, REG_SP, slot_disp);
}

// Look up an existing basic block version for an (iseq, idx) tuple
// Returns the registered version, or NULL when none exists
block_t* find_block_version(blockid_t blockid, const ctx_t* ctx)
{
    st_data_t found;

    // No version has been registered under this block id
    if (!rb_st_lookup(version_tbl, (st_data_t)&blockid, &found)) {
        return NULL;
    }

    //
    // TODO: use the ctx parameter to search available versions
    // (it is currently ignored: any version with a matching id is returned)
    //

    return (block_t*)found;
}
// Compile a new block version immediately
// The code is emitted at the current write position of cb and the
// resulting version is registered in version_tbl before being returned
block_t* gen_block_version(blockid_t blockid, const ctx_t* ctx)
{
    // Allocate and fill in the version object
    block_t* version = malloc(sizeof(block_t));
    version->blockid = blockid;
    version->ctx = *ctx;
    version->incoming = NULL;
    version->num_incoming = 0;

    // The block starts at the current position; an end_pos of zero
    // marks the end as not yet known
    version->start_pos = cb->write_pos;
    version->end_pos = 0;

    // Compile the block version
    ujit_gen_code(version);

    // gen_branch may already have terminated the block; otherwise it
    // ends wherever code generation stopped
    if (version->end_pos == 0) {
        version->end_pos = cb->write_pos;
    }

    // Keep track of the new block version, keyed by its own blockid field
    st_insert(version_tbl, (st_data_t)&version->blockid, (st_data_t)version);

    return version;
}

// Generate a block version that is an entry point inserted into an iseq
uint8_t* gen_entry_point(const rb_iseq_t *iseq, uint32_t insn_idx)
{
    // Allocate a version object
    block_t* p_block = malloc(sizeof(block_t));
    blockid_t blockid = { iseq, insn_idx };
    memcpy(&p_block->blockid, &blockid, sizeof(blockid_t));
    p_block->incoming = NULL;
    p_block->num_incoming = 0;
    p_block->end_pos = 0;

    // The entry context makes no assumptions about types
    ctx_t ctx = { 0 };
    memcpy(&p_block->ctx, &ctx, sizeof(ctx_t));

    // The block starts at the current position
    p_block->start_pos = cb->write_pos;

    // Compile the block version
    uint8_t* code_ptr = ujit_gen_entry(p_block);

    // The block may have been terminated in gen_branch
    if (p_block->end_pos == 0)
        p_block->end_pos = cb->write_pos;

    // If we couldn't generate any code
    if (!code_ptr)
    {
        free(p_block);
        return NULL;
    }

    // Keep track of the new block version
    st_insert(version_tbl, (st_data_t)&p_block->blockid, (st_data_t)p_block);

    return code_ptr;
}

// Add an incoming branch for a given block version
// Appends branch_idx to p_block->incoming, growing the list by one element
// and releasing the previous allocation
static void add_incoming(block_t* p_block, uint32_t branch_idx)
{
    uint32_t old_count = p_block->num_incoming;
    uint32_t* old_list = p_block->incoming;

    // Room for every existing entry plus the new one. The size must be
    // computed in elements: sizeof(uint32_t) * old_count + 1 would only
    // reserve a single extra byte and the store below would overflow.
    // NOTE(review): the allocation result is not checked, like the other
    // allocations in this file
    uint32_t* new_list = malloc(sizeof(uint32_t) * ((size_t)old_count + 1));

    // Copy the existing entries (byte count, not element count).
    // The list is NULL while empty, so skip the copy in that case
    if (old_count > 0) {
        memcpy(new_list, old_list, sizeof(uint32_t) * old_count);
    }

    new_list[old_count] = branch_idx;
    p_block->incoming = new_list;
    p_block->num_incoming = old_count + 1;

    // The old list is no longer referenced by the block
    free(old_list);
}

// Called by the generated code when a branch stub is executed
// Triggers compilation of branches and code patching
//
// branch_idx: index of the branch in branch_entries
// target_idx: which of the branch's two targets was reached (0 or 1)
// Returns the address of the compiled target block version; the stub
// emitted by get_branch_target jumps to this address
uint8_t* branch_stub_hit(uint32_t branch_idx, uint32_t target_idx)
{
    assert (branch_idx < num_branches);
    assert (target_idx < 2);
    branch_t *branch = &branch_entries[branch_idx];
    blockid_t target = branch->targets[target_idx];
    ctx_t* target_ctx = &branch->target_ctxs[target_idx];

    //fprintf(stderr, "\nstub hit, branch idx: %d, target idx: %d\n", branch_idx, target_idx);
    //fprintf(stderr, "cb->write_pos=%ld\n", cb->write_pos);
    //fprintf(stderr, "branch->end_pos=%d\n", branch->end_pos);

    // If either of the target blocks will be placed next
    // (the write position is exactly at the end of this branch, so whatever
    // gets compiled next lands immediately after it)
    if (cb->write_pos == branch->end_pos)
    {
        //fprintf(stderr, "target idx %d will be placed next\n", target_idx);
        // Record which target follows the branch so gen_fn can take it
        // into account when emitting the branch
        branch->shape = (uint8_t)target_idx;

        // Rewrite the branch with the new, potentially more compact shape
        // The write position is intentionally left at the end of the
        // rewritten branch, so a newly compiled target starts right there
        cb_set_pos(cb, branch->start_pos);
        branch->gen_fn(cb, branch->dst_addrs[0], branch->dst_addrs[1], branch->shape);
        assert (cb->write_pos <= branch->end_pos);
    }

    // Try to find a compiled version of this block
    block_t* p_block = find_block_version(target, target_ctx);

    // If this block hasn't yet been compiled
    if (!p_block)
    {
        p_block = gen_block_version(target, target_ctx);
    }

    // Add this branch to the list of incoming branches for the target
    add_incoming(p_block, branch_idx);

    // Update the branch target address
    uint8_t* dst_addr = cb_get_ptr(cb, p_block->start_pos);
    branch->dst_addrs[target_idx] = dst_addr;

    // Rewrite the branch with the new jump target address
    // The current write position is saved and restored around the rewrite
    // so that code emitted after the branch is not overwritten
    assert (branch->dst_addrs[0] != NULL);
    assert (branch->dst_addrs[1] != NULL);
    uint32_t cur_pos = cb->write_pos;
    cb_set_pos(cb, branch->start_pos);
    branch->gen_fn(cb, branch->dst_addrs[0], branch->dst_addrs[1], branch->shape);
    // The rewritten branch must fit in the space of the previous encoding
    assert (cb->write_pos <= branch->end_pos);
    branch->end_pos = cb->write_pos;
    cb_set_pos(cb, cur_pos);

    // Return a pointer to the compiled block version
    return dst_addr;
}

// Get a version or stub corresponding to a branch target
// TODO: need incoming and target contexts
//
// If a version of the target already exists, the branch is recorded as one
// of its incoming branches and the address of the version's code is returned.
// Otherwise a stub is emitted into ocb which calls
// branch_stub_hit(branch_idx, target_idx) and jumps to the address it
// returns; the address of that stub is returned instead.
uint8_t* get_branch_target(
    blockid_t target,
    const ctx_t* ctx,
    uint32_t branch_idx,
    uint32_t target_idx
)
{
    block_t* p_block = find_block_version(target, ctx);

    if (p_block)
    {
        // Add an incoming branch for this version
        add_incoming(p_block, branch_idx);

        return cb_get_ptr(cb, p_block->start_pos);
    }

    // Generate an outlined stub that will call
    // branch_stub_hit(uint32_t branch_idx, uint32_t target_idx)
    // The stub starts at the current write position of ocb
    uint8_t* stub_addr = cb_get_ptr(ocb, ocb->write_pos);

    //fprintf(stderr, "REQUESTING STUB FOR IDX: %d\n", target.idx);

    // Save the ujit registers
    // NOTE(review): REG_SP is pushed twice (and popped twice below),
    // presumably to keep the native stack 16-byte aligned for the call;
    // confirm
    push(ocb, REG_CFP);
    push(ocb, REG_EC);
    push(ocb, REG_SP);
    push(ocb, REG_SP);

    // Pass the two arguments in RDI and RSI (the first two integer
    // argument registers under the System V x86-64 calling convention)
    mov(ocb, RDI, imm_opnd(branch_idx));
    mov(ocb, RSI, imm_opnd(target_idx));
    call_ptr(ocb, REG0, (void *)&branch_stub_hit);

    // Restore the ujit registers
    pop(ocb, REG_SP);
    pop(ocb, REG_SP);
    pop(ocb, REG_EC);
    pop(ocb, REG_CFP);

    // Jump to the address returned by the
    // branch_stub_hit call
    jmp_rm(ocb, RAX);

    return stub_addr;
}

// Emit a branch at the current write position of cb and register it in
// branch_entries
//
// src_version: block version that the branch belongs to
// src_ctx: context at the branch, stored in the branch entry
// target0/ctx0: first branch target and its context
// target1/ctx1: second branch target and its context; a target1 with a
//               NULL iseq means the branch has a single target
// gen_fn: callback that emits the branch instruction(s) for a given pair
//         of destination addresses and a shape
void gen_branch(
    block_t* src_version,
    const ctx_t* src_ctx,
    blockid_t target0, 
    const ctx_t* ctx0,
    blockid_t target1, 
    const ctx_t* ctx1,
    branchgen_fn gen_fn
)
{
    assert (num_branches < MAX_BRANCHES);
    // NOTE(review): the index is captured before any target is compiled.
    // If gen_block_version below ends up registering further branches
    // (ujit_gen_code is not visible here), num_branches will have advanced
    // by the time this entry is stored and branch_idx would no longer match
    // the slot it is written to; confirm
    uint32_t branch_idx = num_branches;

    // Branch targets or stub addresses (code pointers)
    uint8_t* dst_addr0;
    uint8_t* dst_addr1;

    // If there's only one branch target
    if (target1.iseq == NULL)
    {
        block_t* p_block = find_block_version(target0, ctx0);

        // If the version doesn't already exist
        if (!p_block)
        {
            // No need for a jump, compile the target block right here
            p_block = gen_block_version(target0, ctx0);

            // The current version ends where the next version begins
            src_version->end_pos = p_block->start_pos;
        }

        add_incoming(p_block, branch_idx);
        dst_addr0 = cb_get_ptr(cb, p_block->start_pos);
        // No second destination for a single-target branch
        dst_addr1 = NULL;
    }
    else
    {
        // Get the branch targets or stubs
        dst_addr0 = get_branch_target(target0, ctx0, branch_idx, 0);
        dst_addr1 = get_branch_target(target1, ctx1, branch_idx, 1);
    }

    // Call the branch generation function
    // The positions before and after are recorded so the branch can be
    // rewritten in place later
    uint32_t start_pos = cb->write_pos;
    gen_fn(cb, dst_addr0, dst_addr1, SHAPE_DEFAULT);
    uint32_t end_pos = cb->write_pos;

    // Register this branch entry
    // Note: ctx1 is dereferenced unconditionally here, so callers must pass
    // a valid pointer even for a single-target branch
    branch_t branch_entry = {
        start_pos,
        end_pos,
        *src_ctx,
        { target0, target1 },
        { *ctx0, *ctx1 },
        { dst_addr0, dst_addr1 },
        gen_fn,
        SHAPE_DEFAULT
    };

    branch_entries[num_branches] = branch_entry;
    num_branches++;
}

// Invalidate one specific block version
// Removes the version from version_tbl, re-points every incoming branch at
// a target obtained from get_branch_target, restores the interpreter's
// handler address at the block's position in its iseq, and frees the
// version object together with its list of incoming branches
void invalidate(block_t* block)
{
    // Remove the version object from the map so we can re-generate stubs
    // The table is keyed by pointers to blockid_t, and st_delete takes a
    // pointer to the key. Passing &block->blockid directly would make it
    // read the first word of the blockid as the key, so the key is placed
    // in its own variable first
    st_data_t key = (st_data_t)&block->blockid;
    st_delete(version_tbl, &key, NULL);

    uint8_t* code_ptr = cb_get_ptr(cb, block->start_pos);

    // For each incoming branch
    for (uint32_t i = 0; i < block->num_incoming; ++i)
    {
        uint32_t branch_idx = block->incoming[i];
        branch_t* branch = &branch_entries[branch_idx];

        // Figure out which of the two destinations pointed at this block
        uint32_t target_idx = (branch->dst_addrs[0] == code_ptr)? 0:1;

        // Create a stub for this branch target
        branch->dst_addrs[target_idx] = get_branch_target(
            block->blockid,
            &block->ctx,
            branch_idx,
            target_idx
        );

        // Check if the invalidated block immediately follows
        bool target_next = block->start_pos == branch->end_pos;

        if (target_next)
        {
            // Reset the branch shape
            branch->shape = SHAPE_DEFAULT;
        }

        // Rewrite the branch with the new jump target address
        // NOTE(review): gen_branch registers single-target branches with a
        // NULL second destination, which the second assert would reject;
        // confirm such branches cannot reach this point
        assert (branch->dst_addrs[0] != NULL);
        assert (branch->dst_addrs[1] != NULL);
        uint32_t cur_pos = cb->write_pos;
        cb_set_pos(cb, branch->start_pos);
        branch->gen_fn(cb, branch->dst_addrs[0], branch->dst_addrs[1], branch->shape);
        branch->end_pos = cb->write_pos;
        cb_set_pos(cb, cur_pos);

        // With the default shape the branch may have grown into the space of
        // the invalidated block; it must not extend beyond that block
        if (target_next && branch->end_pos > block->end_pos)
        {
            rb_bug("ujit invalidate rewrote branch past block end");
        }
    }

    // If the block is an entry point, it needs to be unmapped from its iseq
    const rb_iseq_t* iseq = block->blockid.iseq;
    uint32_t idx = block->blockid.idx;
    VALUE* entry_pc = &iseq->body->iseq_encoded[idx];
    int entry_opcode = opcode_at_pc(iseq, entry_pc);

    // TODO: unmap_addr2insn in ujit_iface.c? Maybe we can write a function to encompass this logic?
    // Should check how it's used in exit and side-exit
    const void * const *handler_table = rb_vm_get_insns_address_table();
    void* handler_addr = (void*)handler_table[entry_opcode];
    iseq->body->iseq_encoded[idx] = (VALUE)handler_addr;

    //
    // Optional: may want to recompile a new deoptimized entry point
    //

    // TODO:
    // Call continuation addresses on the stack can also be atomically replaced by jumps going to the stub.
    // For now this isn't an issue

    // Free the list of incoming branches (NULL if there never were any)
    // and then the block version object itself
    free(block->incoming);
    free(block);
}

// Comparison callback for hash tables keyed by pointers to blockid_t
// Follows the st_hash_type compare convention (same as st_numcmp):
// returns 0 when both keys refer to the same (iseq, idx) pair and
// non-zero when they differ
int blockid_cmp(st_data_t arg0, st_data_t arg1)
{
    const blockid_t *block0 = (const blockid_t*)arg0;
    const blockid_t *block1 = (const blockid_t*)arg1;
    return (block0->iseq != block1->iseq) || (block0->idx != block1->idx);
}

// Hash callback for hash tables keyed by pointers to blockid_t
// Hashes the iseq pointer and the instruction index separately
// and combines the two results
st_index_t blockid_hash(st_data_t arg)
{
    const blockid_t *key = (const blockid_t*)arg;

    const st_data_t iseq_bits = (st_data_t)key->iseq;
    const st_data_t idx_bits = (st_data_t)(uint64_t)key->idx;

    // Use XOR to combine the hashes
    return st_numhash(iseq_bits) ^ st_numhash(idx_bits);
}

// Hash type descriptor for tables keyed by pointers to blockid_t:
// the key comparison callback followed by the key hashing callback
static const struct st_hash_type hashtype_blockid = {
    blockid_cmp,
    blockid_hash,
};

// Set up the global state of this file
// Must run before any block version is looked up or registered, since
// those operations go through version_tbl
void
ujit_init_core(void)
{
    // Initialize the version hash table, keyed by blockid_t pointers
    version_tbl = st_init_table(&hashtype_blockid);
}